You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2022/07/05 09:43:07 UTC

[doris] branch master updated: [feature-wip](unique-key-merge-on-write)port IntervalTree from kudu (#10511)

This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 86502b014d [feature-wip](unique-key-merge-on-write)port IntervalTree from kudu (#10511)
86502b014d is described below

commit 86502b014d8412dd84b8a5a390114ceda9feba85
Author: zhannngchen <48...@users.noreply.github.com>
AuthorDate: Tue Jul 5 17:43:01 2022 +0800

    [feature-wip](unique-key-merge-on-write)port IntervalTree from kudu (#10511)
    
    See DSIP-18: https://cwiki.apache.org/confluence/display/DORIS/DSIP-018%3A+Support+Merge-On-Write+implementation+for+UNIQUE+KEY+data+model
    This patch covers step 3.1 of the DSIP's scheduling plan.
---
 be/src/util/interval_tree-inl.h     | 440 ++++++++++++++++++++++++++++++++++++
 be/src/util/interval_tree.h         | 159 +++++++++++++
 be/test/CMakeLists.txt              |   1 +
 be/test/util/interval_tree_test.cpp | 392 ++++++++++++++++++++++++++++++++
 4 files changed, 992 insertions(+)

diff --git a/be/src/util/interval_tree-inl.h b/be/src/util/interval_tree-inl.h
new file mode 100644
index 0000000000..d322d260d5
--- /dev/null
+++ b/be/src/util/interval_tree-inl.h
@@ -0,0 +1,440 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+// This file is copied from
+// https://github.com/apache/kudu/blob/master/src/kudu/util/interval_tree-inl.h
+// and modified by Doris
+//
+
+#pragma once
+
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "util/interval_tree.h"
+
+namespace doris {
+
+// Builds a static tree over 'intervals'. The intervals are copied into the tree's
+// nodes, so the caller's vector is not retained. An empty input yields an empty
+// tree (root_ == nullptr), for which every query finds nothing.
+template <class Traits>
+IntervalTree<Traits>::IntervalTree(const IntervalVector& intervals) : root_(nullptr) {
+    if (!intervals.empty()) {
+        root_ = CreateNode(intervals);
+    }
+}
+
+// Releases the whole tree: every ITNode deletes its own children, so deleting the
+// root frees all nodes. Deleting a null root (empty tree) is a no-op.
+template <class Traits>
+IntervalTree<Traits>::~IntervalTree() {
+    delete root_;
+}
+
+// Appends to '*results' every stored interval that contains 'query'
+// by delegating to the root node. An empty tree contributes nothing.
+template <class Traits>
+template <class QueryPointType>
+void IntervalTree<Traits>::FindContainingPoint(const QueryPointType& query,
+                                               IntervalVector* results) const {
+    if (root_ == nullptr) {
+        return;
+    }
+    root_->FindContainingPoint(query, results);
+}
+
+// Invokes 'cb(point, interval)' for every (query point, containing interval) pair.
+// The full [begin, end) range of 'queries' is handed to the root node, which expects
+// it to be sorted. An empty tree makes no callbacks.
+template <class Traits>
+template <class Callback, class QueryContainer>
+void IntervalTree<Traits>::ForEachIntervalContainingPoints(const QueryContainer& queries,
+                                                           const Callback& cb) const {
+    if (root_ == nullptr) {
+        return;
+    }
+    const auto first = queries.begin();
+    const auto last = queries.end();
+    root_->ForEachIntervalContainingPoints(first, last, cb);
+}
+
+// Appends to '*results' every stored interval that intersects the query range
+// bounded by 'lower_bound' and 'upper_bound', by delegating to the root node.
+// An empty tree contributes nothing.
+template <class Traits>
+template <class QueryPointType>
+void IntervalTree<Traits>::FindIntersectingInterval(const QueryPointType& lower_bound,
+                                                    const QueryPointType& upper_bound,
+                                                    IntervalVector* results) const {
+    if (root_ == nullptr) {
+        return;
+    }
+    root_->FindIntersectingInterval(lower_bound, upper_bound, results);
+}
+
+// Comparator adaptor over Traits::compare(): true iff 'a' orders before 'b'.
+// Used to order interval endpoints when choosing a split point.
+template <class Traits>
+static bool LessThan(const typename Traits::point_type& a, const typename Traits::point_type& b) {
+    const auto order = Traits::compare(a, b);
+    return order < 0;
+}
+
+// Select a split point which attempts to evenly divide 'in' into three groups:
+//  (a) those that are fully left of the split point
+//  (b) those that overlap the split point (including those touching it at an endpoint)
+//  (c) those that are fully right of the split point
+// These three groups are appended (the vectors are not cleared first) to the output
+// parameters '*left', '*overlapping', and '*right', respectively. The selected split
+// point is stored in *split_point. The split point is always one of the input
+// endpoints. Provided every interval has left <= right, the interval owning that
+// endpoint overlaps it, so '*overlapping' is non-empty and the recursion in
+// CreateNode() makes progress.
+//
+// For example, the input interval set:
+//
+//   |------1-------|         |-----2-----|
+//       |--3--|    |---4--|    |----5----|
+//                     |
+// Resulting split:    | Partition point
+//                     |
+//
+// *left: intervals 1 and 3
+// *overlapping: interval 4
+// *right: intervals 2 and 5
+//
+// REQUIRES: 'in' is not empty.
+template <class Traits>
+void IntervalTree<Traits>::Partition(const IntervalVector& in, point_type* split_point,
+                                     IntervalVector* left, IntervalVector* overlapping,
+                                     IntervalVector* right) {
+    CHECK(!in.empty());
+
+    // Pick a split point which is the median of all of the interval boundaries.
+    // Only the median is needed, so select it in O(n) with std::nth_element instead
+    // of fully sorting the endpoints in O(n log n). The selected element is
+    // equivalent (under Traits::compare) to the one a full sort would put there.
+    std::vector<point_type> endpoints;
+    endpoints.reserve(in.size() * 2);
+    for (const interval_type& interval : in) {
+        endpoints.push_back(Traits::get_left(interval));
+        endpoints.push_back(Traits::get_right(interval));
+    }
+    const auto median = endpoints.begin() + endpoints.size() / 2;
+    std::nth_element(endpoints.begin(), median, endpoints.end(), LessThan<Traits>);
+    // 'endpoints' is local scratch space, so the median can be moved out of it.
+    *split_point = std::move(*median);
+
+    // Partition into the groups based on the determined split point.
+    for (const interval_type& interval : in) {
+        if (Traits::compare(Traits::get_right(interval), *split_point) < 0) {
+            //                 | split point
+            // |------------|  |
+            //    interval
+            left->push_back(interval);
+        } else if (Traits::compare(Traits::get_left(interval), *split_point) > 0) {
+            //                 | split point
+            //                 |    |------------|
+            //                         interval
+            right->push_back(interval);
+        } else {
+            //                 | split point
+            //                 |
+            //          |------------|
+            //             interval
+            overlapping->push_back(interval);
+        }
+    }
+}
+
+// Recursively builds the subtree for 'intervals' and returns its root. The caller
+// takes ownership. Intervals overlapping the chosen split point stay in the new node.
+// Those fully left or right of it go into the left or right subtree.
+//
+// REQUIRES: 'intervals' is not empty (Partition() CHECKs this).
+template <class Traits>
+typename IntervalTree<Traits>::node_type* IntervalTree<Traits>::CreateNode(
+        const IntervalVector& intervals) {
+    IntervalVector left, right, overlap;
+    point_type split_point;
+
+    // First partition the input intervals and select a split point
+    Partition(intervals, &split_point, &left, &overlap, &right);
+
+    // Recursively subdivide the intervals which are fully left or fully
+    // right of the split point into subtree nodes. The subtrees are held by
+    // unique_ptr until the parent node exists. If building the right subtree
+    // or the parent throws (e.g. std::bad_alloc), the subtrees already built
+    // are freed instead of leaked.
+    std::unique_ptr<node_type> left_node(!left.empty() ? CreateNode(left) : nullptr);
+    std::unique_ptr<node_type> right_node(!right.empty() ? CreateNode(right) : nullptr);
+
+    node_type* node = new node_type(std::move(split_point), left_node.get(), overlap,
+                                    right_node.get());
+    // The new node now owns both children (~ITNode deletes them), so give up ownership.
+    left_node.release();
+    right_node.release();
+    return node;
+}
+
+namespace interval_tree_internal {
+
+// Node in the interval tree. A node owns its 'left_' and 'right_' subtrees and
+// deletes them when it is destroyed.
+template <typename Traits>
+class ITNode {
+private:
+    // Types imported from the traits class.
+    using interval_type = typename Traits::interval_type;
+    using point_type = typename Traits::point_type;
+    using IntervalVector = std::vector<interval_type>;
+
+public:
+    // Takes ownership of 'left' and 'right' (either may be NULL). 'overlap' holds the
+    // intervals overlapping 'split_point'. The node keeps its own sorted copies of them.
+    ITNode(point_type split_point, ITNode<Traits>* left, const IntervalVector& overlap,
+           ITNode<Traits>* right);
+    ~ITNode();
+
+    // Counterpart of IntervalTree::FindContainingPoint(), restricted to this subtree.
+    template <class QueryPointType>
+    void FindContainingPoint(const QueryPointType& query, IntervalVector* results) const;
+
+    // Counterpart of IntervalTree::ForEachIntervalContainingPoints(), restricted to this
+    // subtree. It takes an iterator range rather than a container because each recursive
+    // call handles only a sub-range of the original (sorted) query points.
+    template <class Callback, class ItType>
+    void ForEachIntervalContainingPoints(ItType begin_queries, ItType end_queries,
+                                         const Callback& cb) const;
+
+    // Counterpart of IntervalTree::FindIntersectingInterval(), restricted to this subtree.
+    template <class QueryPointType>
+    void FindIntersectingInterval(const QueryPointType& lower_bound,
+                                  const QueryPointType& upper_bound, IntervalVector* results) const;
+
+private:
+    // Comparators used to order the two copies of the overlapping intervals.
+    static bool SortByAscLeft(const interval_type& a, const interval_type& b);
+    static bool SortByDescRight(const interval_type& a, const interval_type& b);
+
+    // The point this node splits on.
+    point_type split_point_;
+
+    // Intervals overlapping split_point_, ordered by ascending left endpoint.
+    IntervalVector overlapping_by_asc_left_;
+
+    // The same intervals, ordered by descending right endpoint.
+    IntervalVector overlapping_by_desc_right_;
+
+    // Owned subtree of intervals lying entirely left of split_point_, or NULL.
+    ITNode* left_;
+
+    // Owned subtree of intervals lying entirely right of split_point_, or NULL.
+    ITNode* right_;
+
+    DISALLOW_COPY_AND_ASSIGN(ITNode);
+};
+
+// Comparator: orders intervals by ascending left endpoint.
+template <class Traits>
+bool ITNode<Traits>::SortByAscLeft(const interval_type& a, const interval_type& b) {
+    const auto& left_of_a = Traits::get_left(a);
+    const auto& left_of_b = Traits::get_left(b);
+    return Traits::compare(left_of_a, left_of_b) < 0;
+}
+
+// Comparator: orders intervals by descending right endpoint.
+template <class Traits>
+bool ITNode<Traits>::SortByDescRight(const interval_type& a, const interval_type& b) {
+    const auto& right_of_a = Traits::get_right(a);
+    const auto& right_of_b = Traits::get_right(b);
+    return Traits::compare(right_of_a, right_of_b) > 0;
+}
+
+// Stores the split point and takes ownership of the child subtrees. It keeps two
+// copies of 'overlap', each sorted for the direction in which queries scan it.
+template <class Traits>
+ITNode<Traits>::ITNode(typename Traits::point_type split_point, ITNode<Traits>* left,
+                       const IntervalVector& overlap, ITNode<Traits>* right)
+        : split_point_(std::move(split_point)),
+          overlapping_by_asc_left_(overlap),
+          overlapping_by_desc_right_(overlap),
+          left_(left),
+          right_(right) {
+    // Ascending by left boundary: queries left of the split point consume a prefix.
+    std::sort(overlapping_by_asc_left_.begin(), overlapping_by_asc_left_.end(), SortByAscLeft);
+    // Descending by right boundary: queries right of the split point consume a prefix.
+    std::sort(overlapping_by_desc_right_.begin(), overlapping_by_desc_right_.end(),
+              SortByDescRight);
+}
+
+// Recursively frees both owned subtrees. Deleting a null child is a no-op.
+template <class Traits>
+ITNode<Traits>::~ITNode() {
+    delete left_;
+    delete right_;
+}
+
+// Calls 'cb(query_point, interval)' for each query point in [begin_queries, end_queries)
+// and each interval in this subtree that contains it. Query points < split_point_ are
+// checked against the left subtree and the overlapping intervals. Query points
+// >= split_point_ are checked against the overlapping intervals and, except those
+// equal to split_point_, against the right subtree.
+//
+// REQUIRES: the query points are sorted ascending under Traits::compare. In debug
+// builds, only the partitioning around split_point_ is DCHECKed.
+template <class Traits>
+template <class Callback, class ItType>
+void ITNode<Traits>::ForEachIntervalContainingPoints(ItType begin_queries, ItType end_queries,
+                                                     const Callback& cb) const {
+    if (begin_queries == end_queries) return;
+
+    // The type produced by '*begin_queries'. For the const container iterators that
+    // IntervalTree passes in, this is typically a const reference to the element.
+    typedef decltype(*begin_queries) QueryPointType;
+    const auto& partitioner = [&](const QueryPointType& query_point) {
+        return Traits::compare(query_point, split_point_) < 0;
+    };
+
+    // Partition the query points into those less than the split_point_ and those greater
+    // than or equal to the split_point_. Because the input queries are already sorted, we
+    // can use 'std::partition_point' instead of 'std::partition'.
+    //
+    // The resulting 'partition_point' is the first query point in the second group.
+    //
+    // Complexity: O(log(number of query points))
+    DCHECK(std::is_partitioned(begin_queries, end_queries, partitioner));
+    auto partition_point = std::partition_point(begin_queries, end_queries, partitioner);
+
+    // Recurse left: any query points left of the split point may intersect
+    // with non-overlapping intervals fully-left of our split point.
+    if (left_ != NULL) {
+        left_->ForEachIntervalContainingPoints(begin_queries, partition_point, cb);
+    }
+
+    // Handle the query points < split_point                  /
+    //                                                        /
+    //      split_point_                                      /
+    //         |                                              /
+    //   [------]         \                                   /
+    //     [-------]       | overlapping_by_asc_left_         /
+    //       [--------]   /                                   /
+    // Q   Q      Q                                           /
+    // ^   ^      \___ not handled (right of split_point_)    /
+    // |   |                                                  /
+    // \___\___ these points will be handled here             /
+    //
+
+    // Lower bound of query points still relevant.
+    auto rem_queries = begin_queries;
+    for (const interval_type& interval : overlapping_by_asc_left_) {
+        const auto& interval_left = Traits::get_left(interval);
+        // Find those query points which are right of the left side of the interval.
+        // 'first_match' here is the first query point >= interval_left.
+        // Complexity: O(log(num_queries))
+        //
+        // TODO(todd): The non-batched implementation is O(log(num_intervals) * num_queries)
+        // whereas this loop ends up O(num_intervals * log(num_queries)). So, for
+        // small numbers of queries this is not the fastest way to structure these loops.
+        auto first_match = std::partition_point(
+                rem_queries, partition_point, [&](const QueryPointType& query_point) {
+                    return Traits::compare(query_point, interval_left) < 0;
+                });
+        // Each of these points is >= interval_left and < split_point_ <= the interval's
+        // right end, so the interval contains it.
+        for (auto it = first_match; it != partition_point; ++it) {
+            cb(*it, interval);
+        }
+        // Since the intervals are sorted in ascending-left order, we can start
+        // the search for the next interval at the first match in this interval.
+        // (any query point which was left of the current interval will also be left
+        // of all future intervals).
+        rem_queries = std::move(first_match);
+    }
+
+    // Handle the query points >= split_point                        /
+    //                                                               /
+    //    split_point_                                               /
+    //       |                                                       /
+    //     [--------]   \                                            /
+    //   [-------]       | overlapping_by_desc_right_                /
+    // [------]         /                                            /
+    //   Q   Q      Q                                                /
+    //   |    \______\___ these points will be handled here          /
+    //   |                                                           /
+    //   \___ not handled (left of split_point_)                     /
+
+    // Upper bound of query points still relevant.
+    rem_queries = end_queries;
+    for (const interval_type& interval : overlapping_by_desc_right_) {
+        const auto& interval_right = Traits::get_right(interval);
+        // Find the first query point which is > the right side of the interval.
+        auto first_non_match = std::partition_point(
+                partition_point, rem_queries, [&](const QueryPointType& query_point) {
+                    return Traits::compare(query_point, interval_right) <= 0;
+                });
+        for (auto it = partition_point; it != first_non_match; ++it) {
+            cb(*it, interval);
+        }
+        // Same logic as above: if a query point was fully right of 'interval',
+        // then it will be fully right of all following intervals because they are
+        // sorted by descending-right.
+        rem_queries = std::move(first_non_match);
+    }
+
+    if (right_ != NULL) {
+        // Partition() puts an interval in right_ only when its left end is > split_point_,
+        // so query points equal to split_point_ cannot match there. Skip them.
+        while (partition_point != end_queries &&
+               Traits::compare(*partition_point, split_point_) == 0) {
+            ++partition_point;
+        }
+        right_->ForEachIntervalContainingPoints(partition_point, end_queries, cb);
+    }
+}
+
+// Appends to '*results' every interval in this subtree that contains 'query'.
+// At most one child is searched: the left one when query < split_point_, the right
+// one when query > split_point_, and neither when they are equal.
+template <class Traits>
+template <class QueryPointType>
+void ITNode<Traits>::FindContainingPoint(const QueryPointType& query,
+                                         IntervalVector* results) const {
+    const int order = Traits::compare(query, split_point_);
+
+    if (order == 0) {
+        // The query is exactly our split point. Every overlapping interval contains it,
+        // and the children's intervals lie strictly left/right of it.
+        results->insert(results->end(), overlapping_by_asc_left_.begin(),
+                        overlapping_by_asc_left_.end());
+        return;
+    }
+
+    if (order < 0) {
+        // query < split_point_: nothing in right_ can contain it.
+        if (left_ != nullptr) {
+            left_->FindContainingPoint(query, results);
+        }
+        // Every overlapping interval reaches split_point_, which lies beyond 'query'.
+        // It therefore contains 'query' iff it starts at or before it. Those intervals
+        // form a prefix of the ascending-left list.
+        const auto starts_by_query = std::partition_point(
+                overlapping_by_asc_left_.cbegin(), overlapping_by_asc_left_.cend(),
+                [&](const interval_type& interval) {
+                    return Traits::compare(Traits::get_left(interval), query) <= 0;
+                });
+        results->insert(results->end(), overlapping_by_asc_left_.cbegin(), starts_by_query);
+        return;
+    }
+
+    // query > split_point_: nothing in left_ can contain it.
+    if (right_ != nullptr) {
+        right_->FindContainingPoint(query, results);
+    }
+    // This mirrors the case above. An overlapping interval contains 'query' iff it ends
+    // at or after it. Those intervals form a prefix of the descending-right list.
+    const auto ends_after_query = std::partition_point(
+            overlapping_by_desc_right_.cbegin(), overlapping_by_desc_right_.cend(),
+            [&](const interval_type& interval) {
+                return Traits::compare(Traits::get_right(interval), query) >= 0;
+            });
+    results->insert(results->end(), overlapping_by_desc_right_.cbegin(), ends_after_query);
+}
+
+// Appends to '*results' every interval in this subtree that intersects the query range.
+// The query range is treated as [lower_bound, upper_bound): 'lower_bound' is inclusive
+// and 'upper_bound' is exclusive. That is why 'upper_bound <= split_point_' already
+// means the query lies entirely left of the split point. The EndpointIfNone argument
+// passed to Traits::compare() says how a bound without a value is ordered: as
+// +infinity for the upper bound, or -infinity for the lower bound.
+template <class Traits>
+template <class QueryPointType>
+void ITNode<Traits>::FindIntersectingInterval(const QueryPointType& lower_bound,
+                                              const QueryPointType& upper_bound,
+                                              IntervalVector* results) const {
+    if (Traits::compare(upper_bound, split_point_, POSITIVE_INFINITY) <= 0) {
+        // The query range lies entirely left of the split point. Since upper_bound is
+        // exclusive, the range excludes the split point even when upper_bound equals it.
+        // So, it may not overlap with any in 'right_'.
+        if (left_ != NULL) {
+            left_->FindIntersectingInterval(lower_bound, upper_bound, results);
+        }
+
+        // Every overlapping interval ends at or after split_point_, which is
+        // >= upper_bound. So only its left edge decides whether it intersects the query:
+        // the left edge must be < upper_bound. Such intervals form a prefix of
+        // 'overlapping_by_asc_left_'. 'std::partition_point' returns the first interval
+        // that does not meet the criterion, so we insert everything before it.
+        auto first_greater = std::partition_point(
+                overlapping_by_asc_left_.cbegin(), overlapping_by_asc_left_.cend(),
+                [&](const interval_type& interval) {
+                    return Traits::compare(Traits::get_left(interval), upper_bound,
+                                           POSITIVE_INFINITY) < 0;
+                });
+        results->insert(results->end(), overlapping_by_asc_left_.cbegin(), first_greater);
+    } else if (Traits::compare(lower_bound, split_point_, NEGATIVE_INFINITY) > 0) {
+        // The query range lies entirely right of the split point. So, it may not overlap
+        // with any in 'left_'.
+        if (right_ != NULL) {
+            right_->FindIntersectingInterval(lower_bound, upper_bound, results);
+        }
+
+        // Every overlapping interval starts at or before split_point_, which is
+        // < lower_bound. So only its right edge decides whether it intersects the query:
+        // the right edge must be >= lower_bound. Such intervals form a prefix of
+        // 'overlapping_by_desc_right_'. 'std::partition_point' returns the first interval
+        // that does not meet the criterion, so we insert everything before it.
+        auto first_lesser = std::partition_point(
+                overlapping_by_desc_right_.cbegin(), overlapping_by_desc_right_.cend(),
+                [&](const interval_type& interval) {
+                    return Traits::compare(Traits::get_right(interval), lower_bound,
+                                           NEGATIVE_INFINITY) >= 0;
+                });
+        results->insert(results->end(), overlapping_by_desc_right_.cbegin(), first_lesser);
+    } else {
+        // The query interval contains the split point. Therefore all other intervals
+        // which also contain the split point are intersecting.
+        results->insert(results->end(), overlapping_by_asc_left_.begin(),
+                        overlapping_by_asc_left_.end());
+
+        // The query interval may _also_ intersect some in either child.
+        if (left_ != NULL) {
+            left_->FindIntersectingInterval(lower_bound, upper_bound, results);
+        }
+        if (right_ != NULL) {
+            right_->FindIntersectingInterval(lower_bound, upper_bound, results);
+        }
+    }
+}
+
+} // namespace interval_tree_internal
+
+} // namespace doris
diff --git a/be/src/util/interval_tree.h b/be/src/util/interval_tree.h
new file mode 100644
index 0000000000..dd978e8354
--- /dev/null
+++ b/be/src/util/interval_tree.h
@@ -0,0 +1,159 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+// Implements an Interval Tree. See http://en.wikipedia.org/wiki/Interval_tree
+// or CLRS for a full description of the data structure.
+//
+// This file is copied from
+// https://github.com/apache/kudu/blob/master/src/kudu/util/interval_tree.h
+// and modified by Doris
+//
+// Callers of this class should also include interval_tree-inl.h for function
+// definitions.
+
+#pragma once
+
+#include <glog/logging.h>
+
+#include <vector>
+
+#include "gutil/macros.h"
+
+namespace doris {
+
+namespace interval_tree_internal {
+// Node type of the tree. It is defined in interval_tree-inl.h.
+template <class Traits>
+class ITNode;
+} // namespace interval_tree_internal
+
+// Tells the three-argument Traits::compare() how to order a query bound that has
+// no value (e.g. an unset std::optional). FindIntersectingInterval() passes
+// POSITIVE_INFINITY when comparing the query's upper bound and NEGATIVE_INFINITY
+// when comparing its lower bound.
+enum EndpointIfNone { POSITIVE_INFINITY, NEGATIVE_INFINITY };
+
+// Implements an Interval Tree.
+//
+// An Interval Tree is a data structure which stores a set of intervals and supports
+// efficient searches to determine which intervals in that set overlap a query
+// point or interval. These operations are O(lg n + k) where 'n' is the number of
+// intervals in the tree and 'k' is the number of results returned for a given query.
+//
+// This particular implementation is a static tree -- intervals may not be added or
+// removed once the tree is instantiated.
+//
+// This class also assumes that all stored intervals are "closed" intervals -- the
+// intervals are inclusive of their start and end points. The query range passed to
+// FindIntersectingInterval() differs: its upper bound is exclusive.
+//
+// The Traits class should have the following members:
+//   Traits::point_type
+//     a typedef for what a "point" in the range is
+//
+//   Traits::interval_type
+//     a typedef for an interval
+//
+//   static point_type get_left(const interval_type &)
+//   static point_type get_right(const interval_type &)
+//     accessors which fetch the left and right bound of the interval, respectively.
+//
+//   static int compare(const point_type &a, const point_type &b)
+//     return < 0 if a < b, 0 if a == b, > 0 if a > b
+//
+// If a query method is called with a 'QueryPointType' other than point_type,
+// compare() must also accept that type in either argument position.
+// FindIntersectingInterval() additionally requires a three-argument overload
+//   static int compare(const A &a, const B &b, EndpointIfNone if_none)
+// (again with the query type allowed in either position). 'if_none' says whether a
+// query bound without a value orders as +infinity or -infinity.
+//
+// See interval_tree_test.cpp for an example Traits class for 'int' ranges.
+template <class Traits>
+class IntervalTree {
+private:
+    // Import types from the traits class to make code more readable.
+    using interval_type = typename Traits::interval_type;
+    using point_type = typename Traits::point_type;
+
+    // And some convenience types.
+    using IntervalVector = std::vector<interval_type>;
+    using node_type = interval_tree_internal::ITNode<Traits>;
+
+public:
+    // Construct an Interval Tree containing (copies of) the given set of intervals.
+    // An empty vector yields an empty tree, for which every query finds nothing.
+    explicit IntervalTree(const IntervalVector& intervals);
+
+    ~IntervalTree();
+
+    // Find all intervals in the tree which contain the query point.
+    // The resulting intervals are added to the 'results' vector.
+    // The vector is not cleared first.
+    //
+    // NOTE: 'QueryPointType' is usually point_type, but can be any other
+    // type for which there exists the appropriate Traits::compare(...) method.
+    template <class QueryPointType>
+    void FindContainingPoint(const QueryPointType& query, IntervalVector* results) const;
+
+    // For each of the query points in the STL container 'queries', find all
+    // intervals in the tree which may contain those points. Calls 'cb(point, interval)'
+    // for each such interval.
+    //
+    // The points in the query container must be comparable to 'point_type'
+    // using Traits::compare().
+    //
+    // The implementation sequences the calls to 'cb' with the following guarantees:
+    // 1) all of the results corresponding to a given interval will be yielded in at
+    //    most two "groups" of calls (i.e. sub-sequences of calls with the same interval).
+    // 2) within each "group" of calls, the query points will be in ascending order.
+    //
+    // For example, the callback sequence may be:
+    //
+    //  cb(q1, interval_1) -
+    //  cb(q2, interval_1)  | first group of interval_1
+    //  cb(q6, interval_1)  |
+    //  cb(q7, interval_1) -
+    //
+    //  cb(q2, interval_2) -
+    //  cb(q3, interval_2)  | first group of interval_2
+    //  cb(q4, interval_2) -
+    //
+    //  cb(q3, interval_1) -
+    //  cb(q4, interval_1)  | second group of interval_1
+    //  cb(q5, interval_1) -
+    //
+    //  cb(q2, interval_3) -
+    //  cb(q3, interval_3)  | first group of interval_3
+    //  cb(q4, interval_3) -
+    //
+    //  cb(q5, interval_2) -
+    //  cb(q6, interval_2)  | second group of interval_2
+    //  cb(q7, interval_2) -
+    //
+    // REQUIRES: The input points must be pre-sorted or else this will return invalid
+    // results.
+    template <class Callback, class QueryContainer>
+    void ForEachIntervalContainingPoints(const QueryContainer& queries, const Callback& cb) const;
+
+    // Find all intervals in the tree which intersect the query range
+    // [lower_bound, upper_bound). 'lower_bound' is inclusive and 'upper_bound' is
+    // exclusive. A bound without a value is treated as -infinity (lower) or
+    // +infinity (upper) via the three-argument Traits::compare(..., EndpointIfNone).
+    // The resulting intervals are added to the 'results' vector.
+    // The vector is not cleared first.
+    template <class QueryPointType>
+    void FindIntersectingInterval(const QueryPointType& lower_bound,
+                                  const QueryPointType& upper_bound, IntervalVector* results) const;
+
+private:
+    // Split 'in' around a median endpoint into intervals fully left of, overlapping,
+    // and fully right of '*split_point'. The results are appended to the output vectors.
+    static void Partition(const IntervalVector& in, point_type* split_point, IntervalVector* left,
+                          IntervalVector* overlapping, IntervalVector* right);
+
+    // Create a node containing the given intervals, recursively splitting down the tree.
+    // The caller owns the returned node.
+    static node_type* CreateNode(const IntervalVector& intervals);
+
+    // Root of the tree (owned), or null when the tree is empty.
+    node_type* root_;
+
+    DISALLOW_COPY_AND_ASSIGN(IntervalTree);
+};
+
+} // namespace doris
diff --git a/be/test/CMakeLists.txt b/be/test/CMakeLists.txt
index 6b8648b9ba..6ee7bccaed 100644
--- a/be/test/CMakeLists.txt
+++ b/be/test/CMakeLists.txt
@@ -319,6 +319,7 @@ set(UTIL_TEST_FILES
     util/array_parser_test.cpp
     util/quantile_state_test.cpp
     util/hdfs_storage_backend_test.cpp
+    util/interval_tree_test.cpp
 )
 set(VEC_TEST_FILES
     vec/aggregate_functions/agg_test.cpp
diff --git a/be/test/util/interval_tree_test.cpp b/be/test/util/interval_tree_test.cpp
new file mode 100644
index 0000000000..4fb8ed4197
--- /dev/null
+++ b/be/test/util/interval_tree_test.cpp
@@ -0,0 +1,392 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+// This file is copied from
+// https://github.com/apache/kudu/blob/master/src/kudu/util/interval_tree-test.cc
+// and modified by Doris
+
+#include "util/interval_tree.h"
+
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <map>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <tuple> // IWYU pragma: keep
+#include <utility>
+#include <vector>
+
+#include "gutil/stringprintf.h"
+#include "gutil/strings/substitute.h"
+#include "testutil/test_util.h"
+#include "util/interval_tree-inl.h"
+
+using std::pair;
+using std::string;
+using std::vector;
+using strings::Substitute;
+
+namespace doris {
+
+// Empty gtest fixture shared by every interval-tree test case; it carries no
+// per-test state.
+struct TestIntervalTree : testing::Test {};
+
+// Closed integer interval [left, right], tagged with an id so that otherwise
+// equal intervals can be told apart in test output.
+struct IntInterval {
+    IntInterval(int left, int right, int id = -1) : left(left), right(right), id(id) {}
+
+    // Returns whether this closed interval [left, right] overlaps the half-open
+    // query range [lower, upper), where a std::nullopt bound means "unbounded"
+    // on that side (-OO for 'lower', +OO for 'upper').
+    //
+    // The two bounds are checked independently:
+    //   - the query must not start after our right end:   lower <= right
+    //   - the query must end strictly after our left end:  upper >  left
+    //     (upper is exclusive, so upper == left does not touch us)
+    bool Intersects(const std::optional<int>& lower, const std::optional<int>& upper) const {
+        const bool lower_fits = !lower.has_value() || *lower <= this->right;
+        const bool upper_fits = !upper.has_value() || *upper > this->left;
+        return lower_fits && upper_fits;
+    }
+
+    // Renders the interval as "[left, right](id) " (note the trailing space).
+    string ToString() const { return strings::Substitute("[$0, $1]($2) ", left, right, id); }
+
+    int left, right, id;
+};
+
+// A wrapper around an int which can be compared with IntTraits::compare()
+// but also keeps a counter of how many times it has been compared. Used
+// for TestBigO below.
+//
+// Copies share the same counter (it lives behind a shared_ptr), so
+// comparisons made through any copy are all accounted to this query point.
+struct CountingQueryPoint {
+    explicit CountingQueryPoint(int v) : val(v), count(std::make_shared<int>(0)) {}
+
+    int val;
+    // Shared comparison counter; incremented by IntTraits::compare().
+    std::shared_ptr<int> count;
+};
+
+// Traits definition for intervals made up of ints on either end. Besides the
+// plain int/int comparison it offers mixed overloads for CountingQueryPoint
+// (which records every comparison made through it) and for std::optional<int>
+// endpoints, where std::nullopt stands for an infinite endpoint whose sign is
+// selected by 'type'.
+struct IntTraits {
+    using point_type = int;
+    using interval_type = IntInterval;
+
+    static point_type get_left(const IntInterval& x) { return x.left; }
+    static point_type get_right(const IntInterval& x) { return x.right; }
+
+    // Three-way comparison yielding -1, 0 or 1.
+    static int compare(int a, int b) { return (a > b) - (a < b); }
+
+    static int compare(const CountingQueryPoint& q, int b) {
+        ++(*q.count);
+        return compare(q.val, b);
+    }
+    static int compare(int a, const CountingQueryPoint& b) { return -compare(b, a); }
+
+    static int compare(const std::optional<int>& a, const int b, const EndpointIfNone& type) {
+        if (!a.has_value()) {
+            // An absent endpoint is +OO or -OO and never equals a finite point.
+            return POSITIVE_INFINITY == type ? 1 : -1;
+        }
+        return compare(*a, b);
+    }
+
+    static int compare(const int a, const std::optional<int>& b, const EndpointIfNone& type) {
+        return -compare(b, a, type);
+    }
+};
+
+// Strict weak ordering on intervals by (left, right, id). The order itself is
+// arbitrary; it only has to be consistent so that tree results and brute-force
+// results can be sorted identically before being compared. Using an interval
+// tree does not require defining it.
+static bool CompareIntervals(const IntInterval& a, const IntInterval& b) {
+    return std::tie(a.left, a.right, a.id) < std::tie(b.left, b.right, b.id);
+}
+
+// Stringify a list of int intervals, for easy test error reporting.
+// Each interval is rendered with IntInterval::ToString() and consecutive
+// intervals are separated by ",".
+static string Stringify(const vector<IntInterval>& intervals) {
+    string ret;
+    bool first = true;
+    for (const IntInterval& interval : intervals) {
+        if (!first) {
+            ret.append(",");
+        }
+        // Every interval after the first one must be preceded by a separator.
+        first = false;
+        ret.append(interval.ToString());
+    }
+    return ret;
+}
+
+// Brute-force reference for FindContainingPoint: appends to 'results' every
+// interval in 'intervals' whose closed range [left, right] contains
+// 'query_point'. Existing contents of 'results' are kept.
+static void FindContainingBruteForce(const vector<IntInterval>& intervals, int query_point,
+                                     vector<IntInterval>* results) {
+    for (const auto& candidate : intervals) {
+        const bool outside = query_point < candidate.left || query_point > candidate.right;
+        if (outside) {
+            continue;
+        }
+        results->push_back(candidate);
+    }
+}
+
+// Brute-force reference for FindIntersectingInterval: appends to 'results'
+// every interval in 'intervals' for which IntInterval::Intersects() reports an
+// overlap with the query [lower, upper). Existing contents of 'results' are kept.
+static void FindIntersectingBruteForce(const vector<IntInterval>& intervals,
+                                       const std::optional<int>& lower,
+                                       const std::optional<int>& upper,
+                                       vector<IntInterval>* results) {
+    for (const auto& candidate : intervals) {
+        if (!candidate.Intersects(lower, upper)) {
+            continue;
+        }
+        results->push_back(candidate);
+    }
+}
+
+// Checks that IntervalTree::FindContainingPoint() agrees with the O(n)
+// brute-force scan for 'query_point'. Both result sets are sorted with
+// CompareIntervals first, so the order the tree reports hits in is irrelevant.
+static void VerifyFindContainingPoint(const vector<IntInterval>& all_intervals,
+                                      const IntervalTree<IntTraits>& tree, int query_point) {
+    vector<IntInterval> from_tree;
+    tree.FindContainingPoint(query_point, &from_tree);
+    std::sort(from_tree.begin(), from_tree.end(), CompareIntervals);
+
+    vector<IntInterval> expected;
+    FindContainingBruteForce(all_intervals, query_point, &expected);
+    std::sort(expected.begin(), expected.end(), CompareIntervals);
+
+    const string trace = Stringify(all_intervals) + StringPrintf(" {q=%d}", query_point);
+    SCOPED_TRACE(trace);
+    EXPECT_EQ(Stringify(expected), Stringify(from_tree));
+}
+
+// Verify that IntervalTree::FindIntersectingInterval yields the same results as the naive
+// brute-force O(n) algorithm. Four query shapes are derived from 'query_interval':
+// [left, right), [-OO, right), [left, +OO) and [-OO, +OO), where std::nullopt
+// encodes an unbounded side.
+static void VerifyFindIntersectingInterval(const vector<IntInterval>& all_intervals,
+                                           const IntervalTree<IntTraits>& tree,
+                                           const IntInterval& query_interval) {
+    const auto Process = [&](const std::optional<int>& lower, const std::optional<int>& upper) {
+        vector<IntInterval> results;
+        tree.FindIntersectingInterval(lower, upper, &results);
+        std::sort(results.begin(), results.end(), CompareIntervals);
+
+        vector<IntInterval> brute_force;
+        FindIntersectingBruteForce(all_intervals, lower, upper, &brute_force);
+        std::sort(brute_force.begin(), brute_force.end(), CompareIntervals);
+        EXPECT_EQ(Stringify(brute_force), Stringify(results));
+    };
+
+    {
+        // [lower, upper)
+        std::optional<int> lower = query_interval.left;
+        std::optional<int> upper = query_interval.right;
+        SCOPED_TRACE(Stringify(all_intervals) + StringPrintf(" {q=[%d, %d)}", *lower, *upper));
+        Process(lower, upper);
+    }
+
+    {
+        // [-OO, upper)
+        std::optional<int> lower = std::nullopt;
+        std::optional<int> upper = query_interval.right;
+        SCOPED_TRACE(Stringify(all_intervals) + StringPrintf(" {q=[-OO, %d)}", *upper));
+        Process(lower, upper);
+    }
+
+    {
+        // [lower, +OO)
+        std::optional<int> lower = query_interval.left;
+        std::optional<int> upper = std::nullopt;
+        SCOPED_TRACE(Stringify(all_intervals) + StringPrintf(" {q=[%d, +OO)}", *lower));
+        Process(lower, upper);
+    }
+
+    {
+        // [-OO, +OO): both bounds must be std::nullopt here; taking 'lower' from
+        // query_interval.left would merely repeat the [lower, +OO) case above.
+        std::optional<int> lower = std::nullopt;
+        std::optional<int> upper = std::nullopt;
+        SCOPED_TRACE(Stringify(all_intervals) + StringPrintf(" {q=[-OO, +OO)}"));
+        Process(lower, upper);
+    }
+}
+
+// Builds 'n' random intervals with ids 0..n-1. Each left endpoint comes from
+// rand_rng_int(0, 100) and each width (right - left) from rand_rng_int(0, 20).
+static vector<IntInterval> CreateRandomIntervals(int n = 100) {
+    vector<IntInterval> out;
+    for (int id = 0; id < n; ++id) {
+        const int left = rand_rng_int(0, 100); // NOLINT(runtime/threadsafe_fn)
+        const int width = rand_rng_int(0, 20); // NOLINT(runtime/threadsafe_fn)
+        out.emplace_back(left, left + width, id);
+    }
+    return out;
+}
+
+// Three small overlapping intervals. Every point query in [0, 5] and every
+// interval query built from 0 <= lo <= hi <= 5 is checked against brute force.
+TEST_F(TestIntervalTree, TestBasic) {
+    vector<IntInterval> intervals = {IntInterval(1, 2, 1), IntInterval(3, 4, 2),
+                                     IntInterval(1, 4, 3)};
+    IntervalTree<IntTraits> tree(intervals);
+
+    for (int lo = 0; lo <= 5; ++lo) {
+        VerifyFindContainingPoint(intervals, tree, lo);
+        for (int hi = lo; hi <= 5; ++hi) {
+            VerifyFindIntersectingInterval(intervals, tree, IntInterval(lo, hi, 0));
+        }
+    }
+}
+
+TEST_F(TestIntervalTree, TestRandomized) {
+    // Generate 100 random intervals and build an interval tree from them. Left
+    // endpoints come from rand_rng_int(0, 100) and widths from rand_rng_int(0, 20)
+    // (see CreateRandomIntervals), so the intervals span roughly 0-120.
+    vector<IntInterval> intervals = CreateRandomIntervals();
+    IntervalTree<IntTraits> t(intervals);
+
+    // Test that we get the correct result for every point query in [-1, 200],
+    // a range that should comfortably cover every endpoint generated above.
+    for (int i = -1; i < 201; i++) {
+        VerifyFindContainingPoint(intervals, t, i);
+    }
+
+    // Test that we get the correct result for 100 random query intervals whose
+    // left end comes from rand_rng_int(0, 100) and right end from
+    // rand_rng_int(l, l + 100).
+    for (int i = 0; i < 100; i++) {
+        int l = rand_rng_int(0, 100);     // NOLINT(runtime/threadsafe_fn)
+        int r = rand_rng_int(l, l + 100); // NOLINT(runtime/threadsafe_fn)
+        VerifyFindIntersectingInterval(intervals, t, IntInterval(l, r));
+    }
+}
+
+// A tree built from no intervals must agree with brute force (i.e. report no
+// hits) for both a point query and an interval query.
+TEST_F(TestIntervalTree, TestEmpty) {
+    vector<IntInterval> no_intervals;
+    IntervalTree<IntTraits> tree(no_intervals);
+
+    VerifyFindContainingPoint(no_intervals, tree, 1);
+    VerifyFindIntersectingInterval(no_intervals, tree, IntInterval(1, 2, 0));
+}
+
+TEST_F(TestIntervalTree, TestBigO) {
+#ifndef NDEBUG
+    LOG(WARNING) << "big-O results are not valid if DCHECK is enabled";
+    return;
+#endif
+    // For growing tree sizes and query-batch sizes, logs how many comparisons
+    // one batched ForEachIntervalContainingPoints() call needs versus one
+    // FindContainingPoint() call per query, and checks both report the same
+    // number of hits.
+    LOG(INFO) << "num_int\tnum_q\tresults\tsimple\tbatch";
+    for (int num_intervals = 1; num_intervals < 2000; num_intervals *= 2) {
+        vector<IntInterval> intervals = CreateRandomIntervals(num_intervals);
+        IntervalTree<IntTraits> t(intervals);
+        for (int num_queries = 1; num_queries < 2000; num_queries *= 2) {
+            vector<CountingQueryPoint> queries;
+            queries.reserve(num_queries);
+            for (int i = 0; i < num_queries; i++) {
+                queries.emplace_back(rand_rng_int(0, 100));
+            }
+            // The batch API requires its query points to be pre-sorted.
+            std::sort(queries.begin(), queries.end(),
+                      [](const CountingQueryPoint& a, const CountingQueryPoint& b) {
+                          return a.val < b.val;
+                      });
+
+            // Test using batch algorithm. The callback takes the query point by
+            // const reference so no shared_ptr copy (atomic refcount) is made per hit.
+            int num_results_batch = 0;
+            t.ForEachIntervalContainingPoints(
+                    queries, [&](const CountingQueryPoint& /*query_point*/,
+                                 const IntInterval& /*interval*/) { num_results_batch++; });
+            int num_comparisons_batch = 0;
+            for (const auto& q : queries) {
+                num_comparisons_batch += *q.count;
+                *q.count = 0;
+            }
+
+            // Test using one-by-one queries.
+            int num_results_simple = 0;
+            for (const auto& q : queries) {
+                vector<IntInterval> tmp_intervals;
+                t.FindContainingPoint(q, &tmp_intervals);
+                num_results_simple += static_cast<int>(tmp_intervals.size());
+            }
+            int num_comparisons_simple = 0;
+            for (const auto& q : queries) {
+                num_comparisons_simple += *q.count;
+            }
+            ASSERT_EQ(num_results_simple, num_results_batch);
+
+            LOG(INFO) << num_intervals << "\t" << num_queries << "\t" << num_results_simple << "\t"
+                      << num_comparisons_simple << "\t" << num_comparisons_batch;
+        }
+    }
+}
+
+TEST_F(TestIntervalTree, TestMultiQuery) {
+    // NOTE(review): with a single query point the grouping checks below are
+    // only weakly exercised.
+    const int kNumQueries = 1;
+    vector<IntInterval> intervals = CreateRandomIntervals(10);
+    IntervalTree<IntTraits> t(intervals);
+
+    // Generate random queries; the batch API requires them to be sorted.
+    vector<int> queries;
+    for (int i = 0; i < kNumQueries; i++) {
+        queries.push_back(rand_rng_int(0, 100));
+    }
+    std::sort(queries.begin(), queries.end());
+
+    // Reference results: one FindContainingPoint() call per query point.
+    vector<pair<string, int>> results_simple;
+    for (int q : queries) {
+        vector<IntInterval> tmp_intervals;
+        t.FindContainingPoint(q, &tmp_intervals);
+        for (const auto& interval : tmp_intervals) {
+            results_simple.emplace_back(interval.ToString(), q);
+        }
+    }
+
+    // Results of the batched API, recorded in callback order.
+    vector<pair<string, int>> results_batch;
+    t.ForEachIntervalContainingPoints(queries, [&](int query_point, const IntInterval& interval) {
+        results_batch.emplace_back(interval.ToString(), query_point);
+    });
+
+    // Check the property that, when the batch query points are in sorted order,
+    // the results are grouped by interval, and within each interval, sorted by
+    // query point. Each interval may have at most two groups.
+    std::optional<pair<string, int>> prev = std::nullopt;
+    std::map<string, int> intervals_seen;
+    for (const auto& cur : results_batch) {
+        // If it's another query point hitting the same interval,
+        // make sure the query points are returned in order.
+        if (prev && prev->first == cur.first) {
+            EXPECT_GE(cur.second, prev->second) << prev->first;
+        } else {
+            // It's the start of a new group for this interval. Make sure that no
+            // interval is split into more than two groups.
+            EXPECT_LE(++intervals_seen[cur.first], 2)
+                    << "Saw more than two groups for interval " << cur.first;
+        }
+        prev = cur;
+    }
+
+    // Regardless of order, both APIs must report the same (interval, query) pairs.
+    std::sort(results_simple.begin(), results_simple.end());
+    std::sort(results_batch.begin(), results_batch.end());
+    ASSERT_EQ(results_simple, results_batch);
+}
+
+} // namespace doris


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org