You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/12 22:33:43 UTC

[arrow-datafusion] branch master updated: Refactor loser tree code in SortPreservingMerge per PR comments (#4407)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 82bbaa3dd Refactor loser tree code in SortPreservingMerge per PR comments (#4407)
82bbaa3dd is described below

commit 82bbaa3dd25a0b174764946be2cfd94b8eda0a68
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Thu Jan 12 23:33:35 2023 +0100

    Refactor loser tree code in SortPreservingMerge per PR comments (#4407)
    
    * Add docstrings for Sort Preserving Merge / Loser Tree
    
    * refactor: Extract loser tree initialization into its own function
    
    * refactor: Extract loser tree update into its own function
    
    * Update types
    
    * Remove redundant update
    
    * Add TreeUpdate::Pending and TreeUpdate::Error
    
    * Simplify using Poll directly
---
 datafusion/core/src/physical_plan/sorts/cursor.rs  |   1 +
 .../physical_plan/sorts/sort_preserving_merge.rs   | 188 +++++++++++++--------
 2 files changed, 118 insertions(+), 71 deletions(-)

diff --git a/datafusion/core/src/physical_plan/sorts/cursor.rs b/datafusion/core/src/physical_plan/sorts/cursor.rs
index 51110403f..53df698c3 100644
--- a/datafusion/core/src/physical_plan/sorts/cursor.rs
+++ b/datafusion/core/src/physical_plan/sorts/cursor.rs
@@ -109,6 +109,7 @@ impl PartialOrd for SortKeyCursor {
 
 impl Ord for SortKeyCursor {
     fn cmp(&self, other: &Self) -> Ordering {
+        // Order finished cursors greater (last)
         match (self.is_finished(), other.is_finished()) {
             (true, true) => Ordering::Equal,
             (_, true) => Ordering::Less,
diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
index a5800746b..658a5f9fc 100644
--- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -32,7 +32,7 @@ use arrow::{
     record_batch::RecordBatch,
 };
 use futures::stream::{Fuse, FusedStream};
-use futures::{Stream, StreamExt};
+use futures::{ready, Stream, StreamExt};
 use log::debug;
 use tokio::sync::mpsc;
 
@@ -321,13 +321,24 @@ pub(crate) struct SortPreservingMergeStream {
     /// Vector that holds all [`SortKeyCursor`]s
     cursors: Vec<Option<SortKeyCursor>>,
 
-    /// The loser tree that always produces the minimum cursor
+    /// A loser tree that always produces the minimum cursor
     ///
     /// Node 0 stores the top winner, Nodes 1..num_streams store
     /// the loser nodes
+    ///
+    /// This implements a "Tournament Tree" (aka Loser Tree) to keep
+    /// track of the current smallest element at the top. When the top
+    /// record is taken, the tree structure is not modified, and only
+    /// the path from bottom to top is visited, keeping the number of
+    /// comparisons close to the theoretical limit of `log(S)`, where S is the number of input streams.
+    ///
+    /// reference: <https://en.wikipedia.org/wiki/K-way_merge_algorithm#Tournament_Tree>
     loser_tree: Vec<usize>,
 
-    /// Identify whether the loser tree is adjusted
+    /// If the most recently yielded overall winner has been replaced
+    /// within the loser tree. A value of `false` indicates that the
+    /// overall winner has been yielded but the loser tree has not
+    /// been updated
     loser_tree_adjusted: bool,
 
     /// target batch size
@@ -558,46 +569,9 @@ impl SortPreservingMergeStream {
         if self.aborted {
             return Poll::Ready(None);
         }
-        let num_streams = self.streams.num_streams();
-
-        // Init all cursors and the loser tree in the first poll
-        if self.loser_tree.is_empty() {
-            // Ensure all non-exhausted streams have a cursor from which
-            // rows can be pulled
-            for i in 0..num_streams {
-                match futures::ready!(self.maybe_poll_stream(cx, i)) {
-                    Ok(_) => {}
-                    Err(e) => {
-                        self.aborted = true;
-                        return Poll::Ready(Some(Err(e)));
-                    }
-                }
-            }
-
-            // Init loser tree
-            self.loser_tree.resize(num_streams, usize::MAX);
-            for i in 0..num_streams {
-                let mut winner = i;
-                let mut cmp_node = (num_streams + i) / 2;
-                while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX {
-                    let challenger = self.loser_tree[cmp_node];
-                    let challenger_win =
-                        match (&self.cursors[winner], &self.cursors[challenger]) {
-                            (None, _) => true,
-                            (_, None) => false,
-                            (Some(winner), Some(challenger)) => challenger < winner,
-                        };
-                    if challenger_win {
-                        self.loser_tree[cmp_node] = winner;
-                        winner = challenger;
-                    } else {
-                        self.loser_tree[cmp_node] = challenger;
-                    }
-                    cmp_node /= 2;
-                }
-                self.loser_tree[cmp_node] = winner;
-            }
-            self.loser_tree_adjusted = true;
+        // try to initialize the loser tree
+        if let Err(e) = ready!(self.init_loser_tree(cx)) {
+            return Poll::Ready(Some(Err(e)));
         }
 
         // NB timer records time taken on drop, so there are no
@@ -606,34 +580,9 @@ impl SortPreservingMergeStream {
         let _timer = elapsed_compute.timer();
 
         loop {
-            // Adjust the loser tree if necessary
-            if !self.loser_tree_adjusted {
-                let mut winner = self.loser_tree[0];
-                match futures::ready!(self.maybe_poll_stream(cx, winner)) {
-                    Ok(_) => {}
-                    Err(e) => {
-                        self.aborted = true;
-                        return Poll::Ready(Some(Err(e)));
-                    }
-                }
-
-                let mut cmp_node = (num_streams + winner) / 2;
-                while cmp_node != 0 {
-                    let challenger = self.loser_tree[cmp_node];
-                    let challenger_win =
-                        match (&self.cursors[winner], &self.cursors[challenger]) {
-                            (None, _) => true,
-                            (_, None) => false,
-                            (Some(winner), Some(challenger)) => challenger < winner,
-                        };
-                    if challenger_win {
-                        self.loser_tree[cmp_node] = winner;
-                        winner = challenger;
-                    }
-                    cmp_node /= 2;
-                }
-                self.loser_tree[0] = winner;
-                self.loser_tree_adjusted = true;
+            // Adjust the loser tree if necessary, returning control if needed
+            if let Err(e) = ready!(self.update_loser_tree(cx)) {
+                return Poll::Ready(Some(Err(e)));
             }
 
             let min_cursor_idx = self.loser_tree[0];
@@ -660,6 +609,103 @@ impl SortPreservingMergeStream {
             }
         }
     }
+
+    /// Attempts to initialize the loser tree with one value from each
+    /// non-exhausted input, if possible.
+    ///
+    /// Returns
+    /// * Poll::Pending when more data is needed
+    /// * Poll::Ready(Ok(())) on success
+    /// * Poll::Ready(Err(..)) if any of the inputs errored
+    #[inline]
+    fn init_loser_tree(
+        self: &mut Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<ArrowResult<()>> {
+        let num_streams = self.streams.num_streams();
+
+        if !self.loser_tree.is_empty() {
+            return Poll::Ready(Ok(()));
+        }
+
+        // Ensure all non-exhausted streams have a cursor from which
+        // rows can be pulled
+        for i in 0..num_streams {
+            if let Err(e) = ready!(self.maybe_poll_stream(cx, i)) {
+                self.aborted = true;
+                return Poll::Ready(Err(e));
+            }
+        }
+
+        // Init loser tree
+        self.loser_tree.resize(num_streams, usize::MAX);
+        for i in 0..num_streams {
+            let mut winner = i;
+            let mut cmp_node = (num_streams + i) / 2;
+            while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX {
+                let challenger = self.loser_tree[cmp_node];
+                let challenger_win =
+                    match (&self.cursors[winner], &self.cursors[challenger]) {
+                        (None, _) => true,
+                        (_, None) => false,
+                        (Some(winner), Some(challenger)) => challenger < winner,
+                    };
+
+                if challenger_win {
+                    self.loser_tree[cmp_node] = winner;
+                    winner = challenger;
+                }
+
+                cmp_node /= 2;
+            }
+            self.loser_tree[cmp_node] = winner;
+        }
+        self.loser_tree_adjusted = true;
+        Poll::Ready(Ok(()))
+    }
+
+    /// Attempts to update the loser tree, if possible
+    ///
+    /// Returns
+    /// * Poll::Pending when the winning input was not ready
+    /// * Poll::Ready(Ok(())) on success
+    /// * Poll::Ready(Err(..)) if the winning input errored
+    #[inline]
+    fn update_loser_tree(
+        self: &mut Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<ArrowResult<()>> {
+        if self.loser_tree_adjusted {
+            return Poll::Ready(Ok(()));
+        }
+
+        let num_streams = self.streams.num_streams();
+        let mut winner = self.loser_tree[0];
+        if let Err(e) = ready!(self.maybe_poll_stream(cx, winner)) {
+            self.aborted = true;
+            return Poll::Ready(Err(e));
+        }
+
+        // Replace overall winner by walking tree of losers
+        let mut cmp_node = (num_streams + winner) / 2;
+        while cmp_node != 0 {
+            let challenger = self.loser_tree[cmp_node];
+            let challenger_win = match (&self.cursors[winner], &self.cursors[challenger])
+            {
+                (None, _) => true,
+                (_, None) => false,
+                (Some(winner), Some(challenger)) => challenger < winner,
+            };
+            if challenger_win {
+                self.loser_tree[cmp_node] = winner;
+                winner = challenger;
+            }
+            cmp_node /= 2;
+        }
+        self.loser_tree[0] = winner;
+        self.loser_tree_adjusted = true;
+        Poll::Ready(Ok(()))
+    }
 }
 
 impl RecordBatchStream for SortPreservingMergeStream {