You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kudu.apache.org by to...@apache.org on 2017/04/05 23:11:31 UTC

[2/3] kudu git commit: interval_tree: allow bulk queries

interval_tree: allow bulk queries

This adds an API to interval_tree to allow it to be queried in "bulk".
The caller passes a set of query points and a callback, and the tree
calls the callback once per <interval, point> pair.

This has much better CPU cache efficiency than issuing many separate
queries. Additionally, for a batch of Q query points, locating the points
relevant to each interval is a binary search over the sorted batch, i.e.
logarithmic in Q rather than linear.

More importantly, this will allow rowset-wise processing of batches of
writes, which opens up big perf wins in later patches in this series.

Change-Id: Ifb4da25ca43413fbcae631a7b0f3f16062e4e408
Reviewed-on: http://gerrit.cloudera.org:8080/6481
Tested-by: Todd Lipcon <to...@apache.org>
Reviewed-by: David Ribeiro Alves <dr...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/kudu/repo
Commit: http://git-wip-us.apache.org/repos/asf/kudu/commit/df451907
Tree: http://git-wip-us.apache.org/repos/asf/kudu/tree/df451907
Diff: http://git-wip-us.apache.org/repos/asf/kudu/diff/df451907

Branch: refs/heads/master
Commit: df451907eabea62145226be5bb303a10a6c01004
Parents: 516d67e
Author: Todd Lipcon <to...@apache.org>
Authored: Tue Mar 21 17:19:42 2017 -0700
Committer: Todd Lipcon <to...@apache.org>
Committed: Wed Apr 5 23:10:27 2017 +0000

----------------------------------------------------------------------
 src/kudu/util/interval_tree-inl.h   | 135 +++++++++++++++++++++-
 src/kudu/util/interval_tree-test.cc | 192 +++++++++++++++++++++++++++----
 src/kudu/util/interval_tree.h       |  48 +++++++-
 3 files changed, 347 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kudu/blob/df451907/src/kudu/util/interval_tree-inl.h
----------------------------------------------------------------------
diff --git a/src/kudu/util/interval_tree-inl.h b/src/kudu/util/interval_tree-inl.h
index ec65390..7637317 100644
--- a/src/kudu/util/interval_tree-inl.h
+++ b/src/kudu/util/interval_tree-inl.h
@@ -38,7 +38,8 @@ IntervalTree<Traits>::~IntervalTree() {
 }
 
 template<class Traits>
-void IntervalTree<Traits>::FindContainingPoint(const point_type &query,
+template<class QueryPointType>
+void IntervalTree<Traits>::FindContainingPoint(const QueryPointType &query,
                                                IntervalVector *results) const {
   if (root_) {
     root_->FindContainingPoint(query, results);
@@ -46,6 +47,17 @@ void IntervalTree<Traits>::FindContainingPoint(const point_type &query,
 }
 
 template<class Traits>
+template<class Callback, class QueryContainer>
+void IntervalTree<Traits>::ForEachIntervalContainingPoints(const QueryContainer& queries,
+                                                           const Callback& cb) const {
+  // An empty tree holds no intervals, so no callbacks are made. Otherwise the
+  // whole query range goes to the root; each node then narrows it down to the
+  // contiguous sub-range that is relevant to each of its children.
+  if (!root_) return;
+  root_->ForEachIntervalContainingPoints(queries.begin(), queries.end(), cb);
+}
+
+template<class Traits>
 void IntervalTree<Traits>::FindIntersectingInterval(const interval_type &query,
                                                     IntervalVector *results) const {
   if (root_) {
@@ -153,9 +165,18 @@ class ITNode {
   ~ITNode();
 
   // See IntervalTree::FindContainingPoint(...)
-  void FindContainingPoint(const point_type &query,
+  template<class QueryPointType>
+  void FindContainingPoint(const QueryPointType &query,
                            IntervalVector *results) const;
 
+  // See IntervalTree::ForEachIntervalContainingPoints().
+  // Takes an iterator range because each level of the recursion processes only a
+  // sub-range of the sorted query points; 'ItType' must be a forward iterator.
+  template<class Callback, class ItType>
+  void ForEachIntervalContainingPoints(ItType begin_queries,
+                                       ItType end_queries,
+                                       const Callback& cb) const;
+
   // See IntervalTree::FindIntersectingInterval(...)
   void FindIntersectingInterval(const interval_type &query,
                                 IntervalVector *results) const;
@@ -214,7 +235,115 @@ ITNode<Traits>::~ITNode() {
 }
 
 template<class Traits>
-void ITNode<Traits>::FindContainingPoint(const point_type &query,
+template<class Callback, class ItType>
+void ITNode<Traits>::ForEachIntervalContainingPoints(ItType begin_queries,
+                                                     ItType end_queries,
+                                                     const Callback& cb) const {
+  if (begin_queries == end_queries) return;
+
+  typedef decltype(*begin_queries) QueryPointType;
+  const auto& partitioner = [&](const QueryPointType& query_point) {
+    return Traits::compare(query_point, split_point_) < 0;
+  };
+
+  // Partition the query points into those less than the split_point_ and those greater
+  // than or equal to the split_point_. Because the input queries are already sorted, we
+  // can use 'std::partition_point' instead of 'std::partition'.
+  //
+  // The resulting 'partition_point' is the first query point in the second group.
+  //
+  // Complexity: O(log(number of query points))
+  DCHECK(std::is_partitioned(begin_queries, end_queries, partitioner));
+  auto partition_point = std::partition_point(begin_queries, end_queries, partitioner);
+
+  // Recurse left: any query points left of the split point may intersect
+  // with non-overlapping intervals fully-left of our split point.
+  if (left_ != NULL) {
+    left_->ForEachIntervalContainingPoints(begin_queries, partition_point, cb);
+  }
+
+  // Handle the query points < split_point
+  //
+  //      split_point_
+  //         |
+  //   [------]         -+
+  //     [-------]       | overlapping_by_asc_left_
+  //       [--------]   -+
+  // Q   Q      Q
+  // ^   ^      \___ not handled (right of split_point_)
+  // |   |
+  // \___\___ these points will be handled here
+  //
+
+  // Lower bound of query points still relevant.
+  auto rem_queries = begin_queries;
+  for (const interval_type &interval : overlapping_by_asc_left_) {
+    const auto& interval_left = Traits::get_left(interval);
+    // Find those query points which are right of the left side of the interval.
+    // 'first_match' here is the first query point >= interval_left.
+    // Complexity: O(log(num_queries))
+    //
+    // TODO(todd): The non-batched implementation is O(log(num_intervals) * num_queries)
+    // whereas this loop ends up O(num_intervals * log(num_queries)). So, for
+    // small numbers of queries this is not the fastest way to structure these loops.
+    auto first_match = std::partition_point(
+        rem_queries, partition_point,
+        [&](const QueryPointType& query_point) {
+          return Traits::compare(query_point, interval_left) < 0;
+        });
+    for (auto it = first_match; it != partition_point; ++it) {
+      cb(*it, interval);
+    }
+    // Intervals are sorted by ascending left, so a point left of this interval is
+    // left of all later ones too: resume from here, and stop once none remain.
+    rem_queries = first_match;
+    if (rem_queries == partition_point) break;
+  }
+
+  // Handle the query points >= split_point
+  //
+  //    split_point_
+  //       |
+  //     [--------]   -+
+  //   [-------]       | overlapping_by_desc_right_
+  // [------]         -+
+  //   Q   Q      Q
+  //   |    \______\___ these points will be handled here
+  //   |
+  //   \___ not handled (left of split_point_)
+
+  // Upper bound of query points still relevant.
+  rem_queries = end_queries;
+  for (const interval_type &interval : overlapping_by_desc_right_) {
+    const auto& interval_right = Traits::get_right(interval);
+    // Find the first query point which is > the right side of the interval.
+    auto first_non_match = std::partition_point(
+        partition_point, rem_queries,
+        [&](const QueryPointType& query_point) {
+          return Traits::compare(query_point, interval_right) <= 0;
+        });
+    for (auto it = partition_point; it != first_non_match; ++it) {
+      cb(*it, interval);
+    }
+    // Mirror image of the above: intervals are sorted by descending right, so a
+    // point right of this interval is right of all later ones too.
+    rem_queries = first_non_match;
+    if (rem_queries == partition_point) break;
+  }
+
+  if (right_ != NULL) {
+    // Skip points equal to split_point_: they were handled by the overlapping lists above.
+    while (partition_point != end_queries &&
+           Traits::compare(*partition_point, split_point_) == 0) {
+      ++partition_point;
+    }
+    right_->ForEachIntervalContainingPoints(partition_point, end_queries, cb);
+  }
+}
+
+template<class Traits>
+template<class QueryPointType>
+void ITNode<Traits>::FindContainingPoint(const QueryPointType &query,
                                          IntervalVector *results) const {
   int cmp = Traits::compare(query, split_point_);
   if (cmp < 0) {

http://git-wip-us.apache.org/repos/asf/kudu/blob/df451907/src/kudu/util/interval_tree-test.cc
----------------------------------------------------------------------
diff --git a/src/kudu/util/interval_tree-test.cc b/src/kudu/util/interval_tree-test.cc
index 09c3015..34a1d07 100644
--- a/src/kudu/util/interval_tree-test.cc
+++ b/src/kudu/util/interval_tree-test.cc
@@ -17,17 +17,24 @@
 
 // All rights reserved.
 
-#include <gtest/gtest.h>
 #include <stdlib.h>
 
 #include <algorithm>
+#include <tuple>
+
+#include <boost/optional.hpp>
+#include <glog/stl_logging.h>
+#include <gtest/gtest.h>
 
 #include "kudu/gutil/stringprintf.h"
+#include "kudu/gutil/strings/substitute.h"
 #include "kudu/util/interval_tree.h"
 #include "kudu/util/interval_tree-inl.h"
 #include "kudu/util/test_util.h"
 
 using std::vector;
+using std::string;
+using strings::Substitute;
 
 namespace kudu {
 
@@ -37,7 +44,11 @@ class TestIntervalTree : public KuduTest {
 
 // Simple interval class for integer intervals.
 struct IntInterval {
-  IntInterval(int left_, int right_) : left(left_), right(right_) {}
+  // 'i' is an optional tag (default -1) that identifies the interval in
+  // ToString() output and in CompareIntervals().
+  IntInterval(int l, int r, int i = -1)
+      : left(l), right(r), id(i) {
+  }
 
   bool Intersects(const IntInterval &other) const {
     if (other.left > right) return false;
@@ -45,7 +56,24 @@ struct IntInterval {
     return true;
   }
 
-  int left, right;
+  // Renders the interval as "[left, right](id)", e.g. "[1, 4](3)".
+  string ToString() const {
+    return strings::Substitute("[$0, $1]($2)", left, right, id);
+  }
+  int left, right, id;
+};
+
+// A wrapper around an int which can be compared with IntTraits::compare()
+// but also counts how many times it has been compared. The counter is shared
+// between copies, so comparisons made on copies are still tallied. Used by TestBigO.
+struct CountingQueryPoint {
+  explicit CountingQueryPoint(int v)
+      : val(v),
+        count(std::make_shared<int>(0)) {
+  }
+
+  int val;
+  std::shared_ptr<int> count;
 };
 
 // Traits definition for intervals made up of ints on either end.
@@ -63,17 +91,23 @@ struct IntTraits {
     if (a > b) return 1;
     return 0;
   }
+
+  // CountingQueryPoint overloads: each call bumps the point's shared counter once.
+  static int compare(const CountingQueryPoint& q, int b) {
+    ++*q.count;
+    return compare(q.val, b);
+  }
+  static int compare(int a, const CountingQueryPoint& b) {
+    return -compare(b, a);
+  }
 };
 
-// Compare intervals in a consistent way - this is only used for verifying
-// that the two algorithms come up with the same results. It's not necessary
-// to define this to use an interval tree.
+// Orders intervals lexicographically by (left, right, id). The order itself is
+// arbitrary; it only has to be consistent so the results of the two algorithms
+// can be compared. It's not necessary to define this to use an interval tree.
 static bool CompareIntervals(const IntInterval &a, const IntInterval &b) {
-  if (a.left < b.left) return true;
-  if (a.left > b.left) return false;
-  if (a.right < b.right) return true;
-  if (b.right > b.right) return true;
-  return false; // equal
+  return std::tie(a.left, a.right, a.id) <
+    std::tie(b.left, b.right, b.id);
 }
 
 // Stringify a list of int intervals, for easy test error reporting.
@@ -84,7 +118,7 @@ static string Stringify(const vector<IntInterval> &intervals) {
     if (!first) {
       ret.append(",");
     }
-    StringAppendF(&ret, "[%d, %d]", interval.left, interval.right);
+    ret.append(interval.ToString());
   }
   return ret;
 }
@@ -148,19 +182,28 @@ static void VerifyFindIntersectingInterval(const vector<IntInterval> all_interva
   EXPECT_EQ(Stringify(brute_force), Stringify(results));
 }
 
+// Returns 'n' intervals with ids 0..n-1, left in [0, 100) and right - left in [0, 20).
+static vector<IntInterval> CreateRandomIntervals(int n = 100) {
+  vector<IntInterval> ret;
+  for (int id = 0; id < n; ++id) {
+    const int left = rand() % 100; // NOLINT(runtime/threadsafe_fn)
+    ret.emplace_back(left, left + rand() % 20, id); // NOLINT(runtime/threadsafe_fn)
+  }
+  return ret;
+}
 
 TEST_F(TestIntervalTree, TestBasic) {
   vector<IntInterval> intervals;
-  intervals.push_back(IntInterval(1, 2));
-  intervals.push_back(IntInterval(3, 4));
-  intervals.push_back(IntInterval(1, 4));
+  intervals.emplace_back(1, 2, 1);
+  intervals.emplace_back(3, 4, 2);
+  intervals.emplace_back(1, 4, 3);
   IntervalTree<IntTraits> t(intervals);
 
   for (int i = 0; i <= 5; i++) {
     VerifyFindContainingPoint(intervals, t, i);
 
     for (int j = i; j <= 5; j++) {
-      VerifyFindIntersectingInterval(intervals, t, IntInterval(i, j));
+      VerifyFindIntersectingInterval(intervals, t, IntInterval(i, j, 0));
     }
   }
 }
@@ -169,12 +212,7 @@ TEST_F(TestIntervalTree, TestRandomized) {
   SeedRandom();
 
   // Generate 100 random intervals spanning 0-200 and build an interval tree from them.
-  vector<IntInterval> intervals;
-  for (int i = 0; i < 100; i++) {
-    int l = rand() % 100; // NOLINT(runtime/threadsafe_fn)
-    int r = l + rand() % 100; // NOLINT(runtime/threadsafe_fn)
-    intervals.push_back(IntInterval(l, r));
-  }
+  vector<IntInterval> intervals = CreateRandomIntervals();
   IntervalTree<IntTraits> t(intervals);
 
   // Test that we get the correct result on every possible query.
@@ -195,7 +233,115 @@ TEST_F(TestIntervalTree, TestEmpty) {
   IntervalTree<IntTraits> t(empty);
 
   VerifyFindContainingPoint(empty, t, 1);
-  VerifyFindIntersectingInterval(empty, t, IntInterval(1, 2));
+  VerifyFindIntersectingInterval(empty, t, IntInterval(1, 2, 0));
+}
+
+TEST_F(TestIntervalTree, TestBigO) {
+#ifndef NDEBUG
+  LOG(WARNING) << "big-O results are not valid if DCHECK is enabled";
+  return;
+#endif
+  SeedRandom();
+
+  LOG(INFO) << "num_int\tnum_q\tresults\tsimple\tbatch";
+  for (int num_intervals = 1; num_intervals < 2000; num_intervals *= 2) {
+    vector<IntInterval> intervals = CreateRandomIntervals(num_intervals);
+    IntervalTree<IntTraits> t(intervals);
+    for (int num_queries = 1; num_queries < 2000; num_queries *= 2) {
+      vector<CountingQueryPoint> queries;
+      for (int i = 0; i < num_queries; i++) {
+        queries.emplace_back(rand() % 100);
+      }
+      std::sort(queries.begin(), queries.end(),
+                [](const CountingQueryPoint& a,
+                   const CountingQueryPoint& b) {
+                  return a.val < b.val;
+                });
+
+      // Batch algorithm: count results, then collect and reset comparison counts.
+      int num_results_batch = 0;
+      t.ForEachIntervalContainingPoints(
+          queries,
+          [&](const CountingQueryPoint& query_point, const IntInterval& interval) {
+            num_results_batch++;
+          });
+      int num_comparisons_batch = 0;
+      for (const auto& q : queries) {
+        num_comparisons_batch += *q.count;
+        *q.count = 0;
+      }
+
+      // One-by-one queries: count results and comparisons for the same points.
+      int num_results_simple = 0;
+      for (const auto& q : queries) {
+        vector<IntInterval> found;
+        t.FindContainingPoint(q, &found);
+        num_results_simple += static_cast<int>(found.size());
+      }
+      int num_comparisons_simple = 0;
+      for (const auto& q : queries) {
+        num_comparisons_simple += *q.count;
+      }
+      ASSERT_EQ(num_results_simple, num_results_batch);
+
+      LOG(INFO) << num_intervals << "\t" << num_queries << "\t" << num_results_simple << "\t"
+                << num_comparisons_simple << "\t" << num_comparisons_batch;
+    }
+  }
+}
+
+TEST_F(TestIntervalTree, TestMultiQuery) {
+  SeedRandom();
+  // Use many query points so that the grouping and ordering checks below are meaningful.
+  const int kNumQueries = 50;
+  vector<IntInterval> intervals = CreateRandomIntervals(10);
+  IntervalTree<IntTraits> t(intervals);
+
+  // Generate random queries.
+  vector<int> queries;
+  for (int i = 0; i < kNumQueries; i++) {
+    queries.push_back(rand() % 100);
+  }
+  std::sort(queries.begin(), queries.end());
+
+  vector<pair<string, int>> results_simple;
+  for (int q : queries) {
+    vector<IntInterval> found;
+    t.FindContainingPoint(q, &found);
+    for (const auto& interval : found) {
+      results_simple.emplace_back(interval.ToString(), q);
+    }
+  }
+
+  vector<pair<string, int>> results_batch;
+  t.ForEachIntervalContainingPoints(
+      queries,
+      [&](int query_point, const IntInterval& interval) {
+        results_batch.emplace_back(interval.ToString(), query_point);
+      });
+
+  // Check the property that, when the batch query points are in sorted order,
+  // the results are grouped by interval, and within each interval, sorted by
+  // query point. Each interval may have at most two groups.
+  boost::optional<pair<string, int>> prev = boost::none;
+  std::map<string, int> intervals_seen;
+  for (const auto& cur : results_batch) {
+    // If it's another query point hitting the same interval,
+    // make sure the query points are returned in order.
+    if (prev && prev->first == cur.first) {
+      EXPECT_GE(cur.second, prev->second) << prev->first;
+    } else {
+      // It's the start of a new group for this interval. Each interval may be
+      // split into at most two groups, so make sure we never see a third.
+      EXPECT_LE(++intervals_seen[cur.first], 2)
+          << "Saw more than two groups for interval " << cur.first;
+    }
+    prev = cur;
+  }
+
+  std::sort(results_simple.begin(), results_simple.end());
+  std::sort(results_batch.begin(), results_batch.end());
+  ASSERT_EQ(results_simple, results_batch);
 }
 
 } // namespace kudu

http://git-wip-us.apache.org/repos/asf/kudu/blob/df451907/src/kudu/util/interval_tree.h
----------------------------------------------------------------------
diff --git a/src/kudu/util/interval_tree.h b/src/kudu/util/interval_tree.h
index 8a625d1..a677528 100644
--- a/src/kudu/util/interval_tree.h
+++ b/src/kudu/util/interval_tree.h
@@ -84,15 +84,59 @@ class IntervalTree {
   // Find all intervals in the tree which contain the query point.
   // The resulting intervals are added to the 'results' vector.
   // The vector is not cleared first.
-  void FindContainingPoint(const point_type &query,
+  //
+  // NOTE: 'QueryPointType' is usually point_type, but can be any other
+  // type for which the appropriate Traits::compare(...) overloads exist.
+  template<class QueryPointType>
+  void FindContainingPoint(const QueryPointType &query,
                            IntervalVector *results) const;
 
+  // For each of the query points in the STL container 'queries', find all
+  // intervals in the tree which contain that point, and call 'cb(point, interval)'
+  // once for each such <point, interval> pair.
+  //
+  // The points in the query container must be comparable to 'point_type'
+  // using Traits::compare(), and the container's iterators must be forward iterators.
+  //
+  // The implementation sequences the calls to 'cb' with the following guarantees:
+  // 1) all of the results corresponding to a given interval will be yielded in at
+  //    most two "groups" of calls (i.e. sub-sequences of calls with the same interval).
+  // 2) within each "group" of calls, the query points will be in ascending order.
+  //
+  // For example, the callback sequence may be:
+  //
+  //  cb(q1, interval_1) -
+  //  cb(q2, interval_1)  | first group of interval_1
+  //  cb(q6, interval_1)  |
+  //  cb(q7, interval_1) -
+  //
+  //  cb(q2, interval_2) -
+  //  cb(q3, interval_2)  | first group of interval_2
+  //  cb(q4, interval_2) -
+  //
+  //  cb(q3, interval_1) -
+  //  cb(q4, interval_1)  | second group of interval_1
+  //  cb(q5, interval_1) -
+  //
+  //  cb(q2, interval_3) -
+  //  cb(q3, interval_3)  | first group of interval_3
+  //  cb(q4, interval_3) -
+  //
+  //  cb(q5, interval_2) -
+  //  cb(q6, interval_2)  | second group of interval_2
+  //  cb(q7, interval_2) -
+  //
+  // REQUIRES: The query points must be sorted in ascending order (as defined by
+  // Traits::compare()); otherwise the results are invalid.
+  template<class Callback, class QueryContainer>
+  void ForEachIntervalContainingPoints(const QueryContainer& queries,
+                                       const Callback& cb) const;
+
   // Find all intervals in the tree which intersect the given interval.
   // The resulting intervals are added to the 'results' vector.
   // The vector is not cleared first.
   void FindIntersectingInterval(const interval_type &query,
                                 IntervalVector *results) const;
-
  private:
   static void Partition(const IntervalVector &in,
                         point_type *split_point,