You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/11/21 17:53:06 UTC

[GitHub] [arrow-datafusion] tustvold commented on a diff in pull request #4301: Use tournament loser tree for k-way sort-merging

tustvold commented on code in PR #4301:
URL: https://github.com/apache/arrow-datafusion/pull/4301#discussion_r1028338480


##########
datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs:
##########
@@ -570,45 +606,57 @@ impl SortPreservingMergeStream {
         let _timer = elapsed_compute.timer();
 
         loop {
-            match self.heap.pop() {
-                Some(Reverse(mut cursor)) => {
-                    let stream_idx = cursor.stream_idx();
-                    let batch_idx = self.batches[stream_idx].len() - 1;
-                    let row_idx = cursor.advance();
-
-                    let mut cursor_finished = false;
-                    // insert the cursor back to heap if the record batch is not exhausted
-                    if !cursor.is_finished() {
-                        self.heap.push(Reverse(cursor));
-                    } else {
-                        cursor_finished = true;
-                        self.cursor_finished[stream_idx] = true;
+            // Adjust the loser tree if necessary
+            if !self.loser_tree_adjusted {
+                let mut winner = self.loser_tree[0];
+                match futures::ready!(self.maybe_poll_stream(cx, winner)) {
+                    Ok(_) => {}
+                    Err(e) => {
+                        self.aborted = true;
+                        return Poll::Ready(Some(Err(e)));
                     }
+                }
 
-                    self.in_progress.push(RowIndex {
-                        stream_idx,
-                        batch_idx,
-                        row_idx,
-                    });
-
-                    if self.in_progress.len() == self.batch_size {
-                        return Poll::Ready(Some(self.build_record_batch()));
+                let mut cmp_node = (num_streams + winner) / 2;

Review Comment:
   ```suggestion
                   // Replace overall winner by walking tree of losers
                   let mut cmp_node = (num_streams + winner) / 2;
   ```



##########
datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs:
##########
@@ -323,8 +318,17 @@ pub(crate) struct SortPreservingMergeStream {
     /// An id to uniquely identify the input stream batch
     next_batch_id: usize,
 
-    /// Heap that yields [`SortKeyCursor`] in increasing order
-    heap: BinaryHeap<Reverse<SortKeyCursor>>,
+    /// Vector that holds all [`SortKeyCursor`]s
+    cursors: Vec<Option<SortKeyCursor>>,
+
+    /// The loser tree that always produces the minimum cursor
+    ///
+    /// Node 0 stores the top winner, Nodes 1..num_streams store
+    /// the loser nodes
+    loser_tree: Vec<usize>,
+
+    /// Identify whether the loser tree is adjusted

Review Comment:
   ```suggestion
       /// Identify whether the most recently yielded overall winner has been replaced
       /// within the loser tree. A value of `false` indicates that the overall winner
       /// has been yielded but the loser tree has not been updated
   ```
   
   Or something to make it clearer what adjusted actually means.
   
   FWIW a boolean of `should_replace_winner` or something might be clearer



##########
datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs:
##########
@@ -551,17 +558,46 @@ impl SortPreservingMergeStream {
         if self.aborted {
             return Poll::Ready(None);
         }
+        let num_streams = self.streams.num_streams();
+
+        // Init all cursors and the loser tree in the first poll
+        if self.loser_tree.is_empty() {
+            // Ensure all non-exhausted streams have a cursor from which

Review Comment:
   It might be easier to follow if this method were split into a method called `init_loser_tree` with a doc comment explaining what it does



##########
datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs:
##########
@@ -570,45 +606,57 @@ impl SortPreservingMergeStream {
         let _timer = elapsed_compute.timer();
 
         loop {
-            match self.heap.pop() {
-                Some(Reverse(mut cursor)) => {
-                    let stream_idx = cursor.stream_idx();
-                    let batch_idx = self.batches[stream_idx].len() - 1;
-                    let row_idx = cursor.advance();
-
-                    let mut cursor_finished = false;
-                    // insert the cursor back to heap if the record batch is not exhausted
-                    if !cursor.is_finished() {
-                        self.heap.push(Reverse(cursor));
-                    } else {
-                        cursor_finished = true;
-                        self.cursor_finished[stream_idx] = true;
+            // Adjust the loser tree if necessary
+            if !self.loser_tree_adjusted {
+                let mut winner = self.loser_tree[0];

Review Comment:
   It might be easier to follow if this were moved into a method called `replace_loser_tree_winner`



##########
datafusion/core/src/physical_plan/sorts/cursor.rs:
##########
@@ -109,8 +109,14 @@ impl PartialOrd for SortKeyCursor {
 
 impl Ord for SortKeyCursor {
     fn cmp(&self, other: &Self) -> Ordering {
-        self.current()
-            .cmp(&other.current())
-            .then_with(|| self.stream_idx.cmp(&other.stream_idx))
+        match (self.is_finished(), other.is_finished()) {

Review Comment:
   ```suggestion
           // Order finished cursors last
           match (self.is_finished(), other.is_finished()) {
   ```



##########
datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs:
##########
@@ -570,45 +606,57 @@ impl SortPreservingMergeStream {
         let _timer = elapsed_compute.timer();
 
         loop {
-            match self.heap.pop() {
-                Some(Reverse(mut cursor)) => {
-                    let stream_idx = cursor.stream_idx();
-                    let batch_idx = self.batches[stream_idx].len() - 1;
-                    let row_idx = cursor.advance();
-
-                    let mut cursor_finished = false;
-                    // insert the cursor back to heap if the record batch is not exhausted
-                    if !cursor.is_finished() {
-                        self.heap.push(Reverse(cursor));
-                    } else {
-                        cursor_finished = true;
-                        self.cursor_finished[stream_idx] = true;
+            // Adjust the loser tree if necessary
+            if !self.loser_tree_adjusted {
+                let mut winner = self.loser_tree[0];
+                match futures::ready!(self.maybe_poll_stream(cx, winner)) {
+                    Ok(_) => {}
+                    Err(e) => {
+                        self.aborted = true;
+                        return Poll::Ready(Some(Err(e)));
                     }
+                }
 
-                    self.in_progress.push(RowIndex {
-                        stream_idx,
-                        batch_idx,
-                        row_idx,
-                    });
-
-                    if self.in_progress.len() == self.batch_size {
-                        return Poll::Ready(Some(self.build_record_batch()));
+                let mut cmp_node = (num_streams + winner) / 2;
+                while cmp_node != 0 {
+                    let challenger = self.loser_tree[cmp_node];
+                    let challenger_win =
+                        match (&self.cursors[winner], &self.cursors[challenger]) {
+                            (None, _) => true,
+                            (_, None) => false,
+                            (Some(winner), Some(challenger)) => challenger < winner,
+                        };
+                    if challenger_win {
+                        self.loser_tree[cmp_node] = winner;
+                        winner = challenger;
                     }
+                    cmp_node /= 2;
+                }
+                self.loser_tree[0] = winner;
+                self.loser_tree_adjusted = true;
+            }
 
-                    // If removed the last row from the cursor, need to fetch a new record
-                    // batch if possible, before looping round again
-                    if cursor_finished {
-                        match futures::ready!(self.maybe_poll_stream(cx, stream_idx)) {
-                            Ok(_) => {}
-                            Err(e) => {
-                                self.aborted = true;
-                                return Poll::Ready(Some(Err(e)));
-                            }
-                        }
-                    }
+            let min_cursor_idx = self.loser_tree[0];

Review Comment:
   I think this could be made easier to follow if it were written along the lines of
   
   
   ```
   let min_cursor = self.cursors[min_cursor_idx];
   if min_cursor.is_finished() {
       // All streams are exhausted
       return Poll::Ready((!self.in_progress.is_empty()).then(|| self.build_record_batch()))
   }
   
   self.loser_tree_adjusted = false;
   self.in_progress.push(...)
   if self.in_progress.len() == self.batch_size {
       return Poll::Ready(Some(self.build_record_batch()));
   }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org