Posted to commits@arrow.apache.org by dh...@apache.org on 2023/06/30 14:28:52 UTC

[arrow-datafusion] 01/02: Add fetch to SortPreservingMergeExec

This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch add_fetch_to_sort_preserving_merge_exec
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit d519a8939e564cb2740b93f08e3a546da1c4fdcd
Author: Daniël Heres <da...@coralogix.com>
AuthorDate: Fri Jun 30 16:28:12 2023 +0200

    Add fetch to SortPreservingMergeExec
---
 .../physical_optimizer/global_sort_selection.rs    |   2 +-
 .../src/physical_optimizer/sort_enforcement.rs     |  21 +++++++++++--
 .../core/src/physical_plan/repartition/mod.rs      |   1 +
 datafusion/core/src/physical_plan/sorts/merge.rs   |  34 ++++++++++++++++-----
 datafusion/core/src/physical_plan/sorts/sort.rs    |   4 +--
 .../physical_plan/sorts/sort_preserving_merge.rs   |  25 +++++++++++++--
 datafusion/proto/proto/datafusion.proto            |   2 ++
 datafusion/proto/proto/proto_descriptor.bin        | Bin 0 -> 86986 bytes
 .../src/{generated/prost.rs => datafusion.rs}      |   3 ++
 .../{generated/pbjson.rs => datafusion.serde.rs}   |  19 ++++++++++++
 datafusion/proto/src/generated/pbjson.rs           |  19 ++++++++++++
 datafusion/proto/src/generated/prost.rs            |   3 ++
 datafusion/proto/src/physical_plan/mod.rs          |   1 +
 13 files changed, 120 insertions(+), 14 deletions(-)
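
For reference, a minimal sketch of the builder-style API this commit adds. The
import paths and the helper name below are assumptions for illustration and may
not match the crate's public re-exports:

    use std::sync::Arc;

    use datafusion::physical_expr::PhysicalSortExpr;
    use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
    use datafusion::physical_plan::ExecutionPlan;

    /// Hypothetical helper: wrap a partitioned, per-partition sorted plan in a
    /// SortPreservingMergeExec that returns at most `limit` rows.
    fn merge_sorted_with_limit(
        input: Arc<dyn ExecutionPlan>,
        sort_exprs: Vec<PhysicalSortExpr>,
        limit: Option<usize>,
    ) -> Arc<dyn ExecutionPlan> {
        // `with_fetch(None)` keeps the previous "merge everything" behaviour.
        Arc::new(SortPreservingMergeExec::new(sort_exprs, input).with_fetch(limit))
    }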

diff --git a/datafusion/core/src/physical_optimizer/global_sort_selection.rs b/datafusion/core/src/physical_optimizer/global_sort_selection.rs
index 9466297d24..0b9054f89f 100644
--- a/datafusion/core/src/physical_optimizer/global_sort_selection.rs
+++ b/datafusion/core/src/physical_optimizer/global_sort_selection.rs
@@ -70,7 +70,7 @@ impl PhysicalOptimizerRule for GlobalSortSelection {
                                 Arc::new(SortPreservingMergeExec::new(
                                     sort_exec.expr().to_vec(),
                                     Arc::new(sort),
-                                ));
+                                ).with_fetch(sort_exec.fetch()));
                             Some(global_sort)
                         } else {
                             None
diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs
index 719c152841..f10401877f 100644
--- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs
+++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs
@@ -422,7 +422,8 @@ fn parallelize_sorts(
         update_child_to_remove_coalesce(&mut prev_layer, &mut coalesce_onwards[0])?;
         let sort_exprs = get_sort_exprs(&plan)?;
         add_sort_above(&mut prev_layer, sort_exprs.to_vec())?;
-        let spm = SortPreservingMergeExec::new(sort_exprs.to_vec(), prev_layer);
+        let sort_fetch = get_sort_fetch(&plan)?;
+        let spm = SortPreservingMergeExec::new(sort_exprs.to_vec(), prev_layer).with_fetch(sort_fetch);
         return Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions {
             plan: Arc::new(spm),
             coalesce_onwards: vec![None],
@@ -785,7 +786,7 @@ fn remove_corresponding_sort_from_sub_plan(
     Ok(updated_plan)
 }
 
-/// Converts an [ExecutionPlan] trait object to a [PhysicalSortExpr] slice when possible.
+/// Retrieves the sort expressions from a `SortExec` or a `SortPreservingMergeExec` when possible.
 fn get_sort_exprs(sort_any: &Arc<dyn ExecutionPlan>) -> Result<&[PhysicalSortExpr]> {
     if let Some(sort_exec) = sort_any.as_any().downcast_ref::<SortExec>() {
         Ok(sort_exec.expr())
@@ -801,6 +802,22 @@ fn get_sort_exprs(sort_any: &Arc<dyn ExecutionPlan>) -> Result<&[PhysicalSortExp
     }
 }
 
+/// Retrieves the fetch limit from a `SortExec` or a `SortPreservingMergeExec` when possible.
+fn get_sort_fetch(sort_any: &Arc<dyn ExecutionPlan>) -> Result<Option<usize>> {
+    if let Some(sort_exec) = sort_any.as_any().downcast_ref::<SortExec>() {
+        Ok(sort_exec.fetch())
+    } else if let Some(sort_preserving_merge_exec) =
+        sort_any.as_any().downcast_ref::<SortPreservingMergeExec>()
+    {
+        Ok(sort_preserving_merge_exec.fetch())
+    } else {
+        Err(DataFusionError::Plan(
+            "Given ExecutionPlan is not a SortExec or a SortPreservingMergeExec"
+                .to_string(),
+        ))
+    }
+}
+
 /// Compares physical ordering (output ordering of input executor) with
 /// `partitionby_exprs` and `orderby_keys`
 /// to decide whether existing ordering is sufficient to run current window executor.
diff --git a/datafusion/core/src/physical_plan/repartition/mod.rs b/datafusion/core/src/physical_plan/repartition/mod.rs
index 72ff0c3713..85225eb471 100644
--- a/datafusion/core/src/physical_plan/repartition/mod.rs
+++ b/datafusion/core/src/physical_plan/repartition/mod.rs
@@ -497,6 +497,7 @@ impl ExecutionPlan for RepartitionExec {
                 sort_exprs,
                 BaselineMetrics::new(&self.metrics, partition),
                 context.session_config().batch_size(),
+                None,
             )
         } else {
             Ok(Box::pin(RepartitionStream {
diff --git a/datafusion/core/src/physical_plan/sorts/merge.rs b/datafusion/core/src/physical_plan/sorts/merge.rs
index d8a3cdef4d..f5472dc57f 100644
--- a/datafusion/core/src/physical_plan/sorts/merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/merge.rs
@@ -39,13 +39,14 @@ macro_rules! primitive_merge_helper {
 }
 
 macro_rules! merge_helper {
-    ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident) => {{
+    ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident) => {{
         let streams = FieldCursorStream::<$t>::new($sort, $streams);
         return Ok(Box::pin(SortPreservingMergeStream::new(
             Box::new(streams),
             $schema,
             $tracking_metrics,
             $batch_size,
+            $fetch,
         )));
     }};
 }
@@ -57,17 +58,18 @@ pub(crate) fn streaming_merge(
     expressions: &[PhysicalSortExpr],
     metrics: BaselineMetrics,
     batch_size: usize,
+    fetch: Option<usize>,
 ) -> Result<SendableRecordBatchStream> {
     // Special case single column comparisons with optimized cursor implementations
     if expressions.len() == 1 {
         let sort = expressions[0].clone();
         let data_type = sort.expr.data_type(schema.as_ref())?;
         downcast_primitive! {
-            data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size),
-            DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size)
-            DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size)
-            DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size)
-            DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size)
+            data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch),
+            DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch)
+            DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch)
+            DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch)
+            DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch)
             _ => {}
         }
     }
@@ -78,6 +80,7 @@ pub(crate) fn streaming_merge(
         schema,
         metrics,
         batch_size,
+        fetch
     )))
 }
 
@@ -140,6 +143,12 @@ struct SortPreservingMergeStream<C> {
 
     /// Vector that holds cursors for each non-exhausted input partition
     cursors: Vec<Option<C>>,
+
+    /// Optional number of rows to fetch
+    fetch: Option<usize>,
+
+    /// Number of rows produced so far
+    produced: usize,
 }
 
 impl<C: Cursor> SortPreservingMergeStream<C> {
@@ -148,6 +157,7 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
         schema: SchemaRef,
         metrics: BaselineMetrics,
         batch_size: usize,
+        fetch: Option<usize>,
     ) -> Self {
         let stream_count = streams.partitions();
 
@@ -160,6 +170,8 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
             loser_tree: vec![],
             loser_tree_adjusted: false,
             batch_size,
+            fetch,
+            produced: 0,
         }
     }
 
@@ -227,11 +239,19 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
             if self.advance(stream_idx) {
                 self.loser_tree_adjusted = false;
                 self.in_progress.push_row(stream_idx);
-                if self.in_progress.len() < self.batch_size {
+
+                // stop merging once the fetch limit has been reached
+                if self.fetch.map(|fetch| self.produced + self.in_progress.len() >= fetch).unwrap_or(false) {
+                    self.aborted = true;
+                }
+                if self.in_progress.len() < self.batch_size {
                     continue;
                 }
             }
 
+            self.produced += self.in_progress.len();
+
+            // emit the rows assembled so far as a record batch
             return Poll::Ready(self.in_progress.build_record_batch().transpose());
         }
     }
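
The early-termination check above compares the rows already emitted plus the
rows buffered for the in-progress batch against the optional fetch limit. A
standalone sketch of that predicate (the free function is illustrative, not
part of the patch):

    /// True once `produced` already-emitted rows plus the rows buffered for
    /// the current batch reach the optional fetch limit.
    fn fetch_reached(fetch: Option<usize>, produced: usize, buffered: usize) -> bool {
        fetch
            .map(|fetch| produced + buffered >= fetch)
            .unwrap_or(false)
    }

    // e.g. fetch_reached(Some(100), 96, 4) == true: the stream marks itself
    // aborted and emits the final (possibly short) batch.
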
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs
index 4983b0ea83..205ec706b5 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -189,6 +189,7 @@ impl ExternalSorter {
                 &self.expr,
                 self.metrics.baseline.clone(),
                 self.batch_size,
+                self.fetch,
             )
         } else if !self.in_mem_batches.is_empty() {
             let result = self.in_mem_sort_stream(self.metrics.baseline.clone());
@@ -285,14 +286,13 @@ impl ExternalSorter {
             })
             .collect::<Result<_>>()?;
 
-        // TODO: Pushdown fetch to streaming merge (#6000)
-
         streaming_merge(
             streams,
             self.schema.clone(),
             &self.expr,
             metrics,
             self.batch_size,
+            self.fetch,
         )
     }
 
diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
index 4db1fea2a4..724ec1de3c 100644
--- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -71,6 +71,8 @@ pub struct SortPreservingMergeExec {
     expr: Vec<PhysicalSortExpr>,
     /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
+    /// Optional number of rows to fetch
+    fetch: Option<usize>,
 }
 
 impl SortPreservingMergeExec {
@@ -80,8 +82,14 @@ impl SortPreservingMergeExec {
             input,
             expr,
             metrics: ExecutionPlanMetricsSet::new(),
+            fetch: None,
         }
     }
+    /// Sets the maximum number of rows to fetch; `None` means no limit
+    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
+        self.fetch = fetch;
+        self
+    }
 
     /// Input schema
     pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
@@ -92,6 +100,12 @@ impl SortPreservingMergeExec {
     pub fn expr(&self) -> &[PhysicalSortExpr] {
         &self.expr
     }
+
+    /// Returns the maximum number of rows to fetch, if any
+    pub fn fetch(&self) -> Option<usize> {
+        self.fetch
+    }
 }
 
 impl ExecutionPlan for SortPreservingMergeExec {
@@ -140,7 +154,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
         Ok(Arc::new(SortPreservingMergeExec::new(
             self.expr.clone(),
             children[0].clone(),
-        )))
+        ).with_fetch(self.fetch)))
     }
 
     fn execute(
@@ -192,6 +206,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
                     &self.expr,
                     BaselineMetrics::new(&self.metrics, partition),
                     context.session_config().batch_size(),
+                    self.fetch,
                 )?;
 
                 debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");
@@ -209,7 +224,12 @@ impl ExecutionPlan for SortPreservingMergeExec {
         match t {
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
                 let expr: Vec<String> = self.expr.iter().map(|e| e.to_string()).collect();
-                write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))
+                write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))?;
+                if let Some(fetch) = self.fetch {
+                    write!(f, ", fetch={fetch}")?;
+                }
+
+                Ok(())
             }
         }
     }
@@ -814,6 +834,7 @@ mod tests {
             sort.as_slice(),
             BaselineMetrics::new(&metrics, 0),
             task_ctx.session_config().batch_size(),
+            None,
         )
         .unwrap();
 
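With this change, a SortPreservingMergeExec that carries a limit renders
roughly as follows in EXPLAIN output (column name and value are illustrative):

    SortPreservingMergeExec: [nullable_col@0 ASC], fetch=10
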
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 4cc80c207c..4f0d324dc1 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1365,6 +1365,8 @@ message SortExecNode {
 message SortPreservingMergeExecNode {
   PhysicalPlanNode input = 1;
   repeated PhysicalExprNode expr = 2;
+  // Maximum number of highest/lowest rows to fetch; negative means no limit
+  int64 fetch = 3;
 }
 
 message CoalesceBatchesExecNode {
diff --git a/datafusion/proto/proto/proto_descriptor.bin b/datafusion/proto/proto/proto_descriptor.bin
new file mode 100644
index 0000000000..448c1eda25
Binary files /dev/null and b/datafusion/proto/proto/proto_descriptor.bin differ
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/datafusion.rs
similarity index 99%
copy from datafusion/proto/src/generated/prost.rs
copy to datafusion/proto/src/datafusion.rs
index 31086deead..ae5dfe14b5 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/datafusion.rs
@@ -1923,6 +1923,9 @@ pub struct SortPreservingMergeExecNode {
     pub input: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
     #[prost(message, repeated, tag = "2")]
     pub expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
+    /// Maximum number of highest/lowest rows to fetch; negative means no limit
+    #[prost(int64, tag = "3")]
+    pub fetch: i64,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/datafusion.serde.rs
similarity index 99%
copy from datafusion/proto/src/generated/pbjson.rs
copy to datafusion/proto/src/datafusion.serde.rs
index 42397e3da2..ab8ddf4f29 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/datafusion.serde.rs
@@ -20269,6 +20269,9 @@ impl serde::Serialize for SortPreservingMergeExecNode {
         if !self.expr.is_empty() {
             len += 1;
         }
+        if self.fetch != 0 {
+            len += 1;
+        }
         let mut struct_ser = serializer.serialize_struct("datafusion.SortPreservingMergeExecNode", len)?;
         if let Some(v) = self.input.as_ref() {
             struct_ser.serialize_field("input", v)?;
@@ -20276,6 +20279,9 @@ impl serde::Serialize for SortPreservingMergeExecNode {
         if !self.expr.is_empty() {
             struct_ser.serialize_field("expr", &self.expr)?;
         }
+        if self.fetch != 0 {
+            struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?;
+        }
         struct_ser.end()
     }
 }
@@ -20288,12 +20294,14 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
         const FIELDS: &[&str] = &[
             "input",
             "expr",
+            "fetch",
         ];
 
         #[allow(clippy::enum_variant_names)]
         enum GeneratedField {
             Input,
             Expr,
+            Fetch,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -20317,6 +20325,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
                         match value {
                             "input" => Ok(GeneratedField::Input),
                             "expr" => Ok(GeneratedField::Expr),
+                            "fetch" => Ok(GeneratedField::Fetch),
                             _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                         }
                     }
@@ -20338,6 +20347,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
             {
                 let mut input__ = None;
                 let mut expr__ = None;
+                let mut fetch__ = None;
                 while let Some(k) = map.next_key()? {
                     match k {
                         GeneratedField::Input => {
@@ -20352,11 +20362,20 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
                             }
                             expr__ = Some(map.next_value()?);
                         }
+                        GeneratedField::Fetch => {
+                            if fetch__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("fetch"));
+                            }
+                            fetch__ = 
+                                Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+                            ;
+                        }
                     }
                 }
                 Ok(SortPreservingMergeExecNode {
                     input: input__,
                     expr: expr__.unwrap_or_default(),
+                    fetch: fetch__.unwrap_or_default(),
                 })
             }
         }
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 42397e3da2..ab8ddf4f29 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -20269,6 +20269,9 @@ impl serde::Serialize for SortPreservingMergeExecNode {
         if !self.expr.is_empty() {
             len += 1;
         }
+        if self.fetch != 0 {
+            len += 1;
+        }
         let mut struct_ser = serializer.serialize_struct("datafusion.SortPreservingMergeExecNode", len)?;
         if let Some(v) = self.input.as_ref() {
             struct_ser.serialize_field("input", v)?;
@@ -20276,6 +20279,9 @@ impl serde::Serialize for SortPreservingMergeExecNode {
         if !self.expr.is_empty() {
             struct_ser.serialize_field("expr", &self.expr)?;
         }
+        if self.fetch != 0 {
+            struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?;
+        }
         struct_ser.end()
     }
 }
@@ -20288,12 +20294,14 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
         const FIELDS: &[&str] = &[
             "input",
             "expr",
+            "fetch",
         ];
 
         #[allow(clippy::enum_variant_names)]
         enum GeneratedField {
             Input,
             Expr,
+            Fetch,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -20317,6 +20325,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
                         match value {
                             "input" => Ok(GeneratedField::Input),
                             "expr" => Ok(GeneratedField::Expr),
+                            "fetch" => Ok(GeneratedField::Fetch),
                             _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                         }
                     }
@@ -20338,6 +20347,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
             {
                 let mut input__ = None;
                 let mut expr__ = None;
+                let mut fetch__ = None;
                 while let Some(k) = map.next_key()? {
                     match k {
                         GeneratedField::Input => {
@@ -20352,11 +20362,20 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
                             }
                             expr__ = Some(map.next_value()?);
                         }
+                        GeneratedField::Fetch => {
+                            if fetch__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("fetch"));
+                            }
+                            fetch__ = 
+                                Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+                            ;
+                        }
                     }
                 }
                 Ok(SortPreservingMergeExecNode {
                     input: input__,
                     expr: expr__.unwrap_or_default(),
+                    fetch: fetch__.unwrap_or_default(),
                 })
             }
         }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 31086deead..ae5dfe14b5 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1923,6 +1923,9 @@ pub struct SortPreservingMergeExecNode {
     pub input: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
     #[prost(message, repeated, tag = "2")]
     pub expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
+    /// Maximum number of highest/lowest rows to fetch; negative means no limit
+    #[prost(int64, tag = "3")]
+    pub fetch: i64,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 1daa1c2e4b..ba1e9e0b09 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -1144,6 +1144,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     Box::new(protobuf::SortPreservingMergeExecNode {
                         input: Some(Box::new(input)),
                         expr,
+                        fetch: exec.fetch().map(|f| f as i64).unwrap_or(-1),
                     }),
                 )),
             })
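
The protobuf field is a plain int64, so the Option<usize> fetch is flattened to
-1 ("no limit") when encoding, as shown above. The decoding side is not shown
in this commit; a sketch of the inverse mapping, assuming the same sentinel
convention:

    /// Inverse of `exec.fetch().map(|f| f as i64).unwrap_or(-1)`:
    /// any negative protobuf value is treated as "no limit".
    fn decode_fetch(fetch: i64) -> Option<usize> {
        if fetch >= 0 {
            Some(fetch as usize)
        } else {
            None
        }
    }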