You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by dh...@apache.org on 2023/06/30 14:28:52 UTC
[arrow-datafusion] 01/02: Add fetch to SortPreservingMergeExec
This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch add_fetch_to_sort_preserving_merge_exec
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit d519a8939e564cb2740b93f08e3a546da1c4fdcd
Author: Daniël Heres <da...@coralogix.com>
AuthorDate: Fri Jun 30 16:28:12 2023 +0200
Add fetch to SortPreservingMergeExec
---
.../physical_optimizer/global_sort_selection.rs | 2 +-
.../src/physical_optimizer/sort_enforcement.rs | 21 +++++++++++--
.../core/src/physical_plan/repartition/mod.rs | 1 +
datafusion/core/src/physical_plan/sorts/merge.rs | 34 ++++++++++++++++-----
datafusion/core/src/physical_plan/sorts/sort.rs | 4 +--
.../physical_plan/sorts/sort_preserving_merge.rs | 25 +++++++++++++--
datafusion/proto/proto/datafusion.proto | 2 ++
datafusion/proto/proto/proto_descriptor.bin | Bin 0 -> 86986 bytes
.../src/{generated/prost.rs => datafusion.rs} | 3 ++
.../{generated/pbjson.rs => datafusion.serde.rs} | 19 ++++++++++++
datafusion/proto/src/generated/pbjson.rs | 19 ++++++++++++
datafusion/proto/src/generated/prost.rs | 3 ++
datafusion/proto/src/physical_plan/mod.rs | 1 +
13 files changed, 120 insertions(+), 14 deletions(-)
diff --git a/datafusion/core/src/physical_optimizer/global_sort_selection.rs b/datafusion/core/src/physical_optimizer/global_sort_selection.rs
index 9466297d24..0b9054f89f 100644
--- a/datafusion/core/src/physical_optimizer/global_sort_selection.rs
+++ b/datafusion/core/src/physical_optimizer/global_sort_selection.rs
@@ -70,7 +70,7 @@ impl PhysicalOptimizerRule for GlobalSortSelection {
Arc::new(SortPreservingMergeExec::new(
sort_exec.expr().to_vec(),
Arc::new(sort),
- ));
+ ).with_fetch(sort_exec.fetch()));
Some(global_sort)
} else {
None
diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs
index 719c152841..f10401877f 100644
--- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs
+++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs
@@ -422,7 +422,8 @@ fn parallelize_sorts(
update_child_to_remove_coalesce(&mut prev_layer, &mut coalesce_onwards[0])?;
let sort_exprs = get_sort_exprs(&plan)?;
add_sort_above(&mut prev_layer, sort_exprs.to_vec())?;
- let spm = SortPreservingMergeExec::new(sort_exprs.to_vec(), prev_layer);
+ let sort_fetch = get_sort_fetch(&plan)?;
+ let spm = SortPreservingMergeExec::new(sort_exprs.to_vec(), prev_layer).with_fetch(sort_fetch);
return Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions {
plan: Arc::new(spm),
coalesce_onwards: vec![None],
@@ -785,7 +786,7 @@ fn remove_corresponding_sort_from_sub_plan(
Ok(updated_plan)
}
-/// Converts an [ExecutionPlan] trait object to a [PhysicalSortExpr] slice when possible.
+/// Converts an [ExecutionPlan] trait object to a [PhysicalSortExpr] slice when possible.
fn get_sort_exprs(sort_any: &Arc<dyn ExecutionPlan>) -> Result<&[PhysicalSortExpr]> {
if let Some(sort_exec) = sort_any.as_any().downcast_ref::<SortExec>() {
Ok(sort_exec.expr())
@@ -801,6 +802,22 @@ fn get_sort_exprs(sort_any: &Arc<dyn ExecutionPlan>) -> Result<&[PhysicalSortExp
}
}
+/// Retrieves the optional fetch limit from a `SortExec` or `SortPreservingMergeExec` when possible.
+fn get_sort_fetch(sort_any: &Arc<dyn ExecutionPlan>) -> Result<Option<usize>> {
+ if let Some(sort_exec) = sort_any.as_any().downcast_ref::<SortExec>() {
+ Ok(sort_exec.fetch())
+ } else if let Some(sort_preserving_merge_exec) =
+ sort_any.as_any().downcast_ref::<SortPreservingMergeExec>()
+ {
+ Ok(sort_preserving_merge_exec.fetch())
+ } else {
+ Err(DataFusionError::Plan(
+ "Given ExecutionPlan is not a SortExec or a SortPreservingMergeExec"
+ .to_string(),
+ ))
+ }
+}
+
/// Compares physical ordering (output ordering of input executor) with
/// `partitionby_exprs` and `orderby_keys`
/// to decide whether existing ordering is sufficient to run current window executor.
diff --git a/datafusion/core/src/physical_plan/repartition/mod.rs b/datafusion/core/src/physical_plan/repartition/mod.rs
index 72ff0c3713..85225eb471 100644
--- a/datafusion/core/src/physical_plan/repartition/mod.rs
+++ b/datafusion/core/src/physical_plan/repartition/mod.rs
@@ -497,6 +497,7 @@ impl ExecutionPlan for RepartitionExec {
sort_exprs,
BaselineMetrics::new(&self.metrics, partition),
context.session_config().batch_size(),
+ None,
)
} else {
Ok(Box::pin(RepartitionStream {
diff --git a/datafusion/core/src/physical_plan/sorts/merge.rs b/datafusion/core/src/physical_plan/sorts/merge.rs
index d8a3cdef4d..f5472dc57f 100644
--- a/datafusion/core/src/physical_plan/sorts/merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/merge.rs
@@ -39,13 +39,14 @@ macro_rules! primitive_merge_helper {
}
macro_rules! merge_helper {
- ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident) => {{
+ ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident) => {{
let streams = FieldCursorStream::<$t>::new($sort, $streams);
return Ok(Box::pin(SortPreservingMergeStream::new(
Box::new(streams),
$schema,
$tracking_metrics,
$batch_size,
+ $fetch,
)));
}};
}
@@ -57,17 +58,18 @@ pub(crate) fn streaming_merge(
expressions: &[PhysicalSortExpr],
metrics: BaselineMetrics,
batch_size: usize,
+ fetch: Option<usize>,
) -> Result<SendableRecordBatchStream> {
// Special case single column comparisons with optimized cursor implementations
if expressions.len() == 1 {
let sort = expressions[0].clone();
let data_type = sort.expr.data_type(schema.as_ref())?;
downcast_primitive! {
- data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size),
- DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size)
- DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size)
- DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size)
- DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size)
+ data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch),
+ DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch)
+ DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch)
+ DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch)
+ DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch)
_ => {}
}
}
@@ -78,6 +80,7 @@ pub(crate) fn streaming_merge(
schema,
metrics,
batch_size,
+ fetch
)))
}
@@ -140,6 +143,12 @@ struct SortPreservingMergeStream<C> {
/// Vector that holds cursors for each non-exhausted input partition
cursors: Vec<Option<C>>,
+
+ /// Optional number of rows to fetch
+ fetch: Option<usize>,
+
+ /// Number of rows produced so far
+ produced: usize,
}
impl<C: Cursor> SortPreservingMergeStream<C> {
@@ -148,6 +157,7 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
schema: SchemaRef,
metrics: BaselineMetrics,
batch_size: usize,
+ fetch: Option<usize>,
) -> Self {
let stream_count = streams.partitions();
@@ -160,6 +170,8 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
loser_tree: vec![],
loser_tree_adjusted: false,
batch_size,
+ fetch,
+ produced: 0,
}
}
@@ -227,11 +239,19 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
if self.advance(stream_idx) {
self.loser_tree_adjusted = false;
self.in_progress.push_row(stream_idx);
- if self.in_progress.len() < self.batch_size {
+
+ // stop sorting if fetch has been reached
+ if self.fetch.map(|fetch| self.produced + self.in_progress.len() >= fetch).unwrap_or(false) {
+ self.aborted = true;
+ }
+ if self.in_progress.len() < self.batch_size {
continue;
}
}
+ self.produced += self.in_progress.len();
+
+
return Poll::Ready(self.in_progress.build_record_batch().transpose());
}
}
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs
index 4983b0ea83..205ec706b5 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -189,6 +189,7 @@ impl ExternalSorter {
&self.expr,
self.metrics.baseline.clone(),
self.batch_size,
+ self.fetch,
)
} else if !self.in_mem_batches.is_empty() {
let result = self.in_mem_sort_stream(self.metrics.baseline.clone());
@@ -285,14 +286,13 @@ impl ExternalSorter {
})
.collect::<Result<_>>()?;
- // TODO: Pushdown fetch to streaming merge (#6000)
-
streaming_merge(
streams,
self.schema.clone(),
&self.expr,
metrics,
self.batch_size,
+ self.fetch,
)
}
diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
index 4db1fea2a4..724ec1de3c 100644
--- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -71,6 +71,8 @@ pub struct SortPreservingMergeExec {
expr: Vec<PhysicalSortExpr>,
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
+ /// Optional number of rows to fetch
+ fetch: Option<usize>
}
impl SortPreservingMergeExec {
@@ -80,8 +82,14 @@ impl SortPreservingMergeExec {
input,
expr,
metrics: ExecutionPlanMetricsSet::new(),
+ fetch: None,
}
}
+ /// Sets the number of rows to fetch
+ pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
+ self.fetch = fetch;
+ self
+ }
/// Input schema
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
@@ -92,6 +100,12 @@ impl SortPreservingMergeExec {
pub fn expr(&self) -> &[PhysicalSortExpr] {
&self.expr
}
+
+ /// Returns the optional maximum number of rows to fetch
+ pub fn fetch(&self) -> Option<usize> {
+ self.fetch
+ }
+
}
impl ExecutionPlan for SortPreservingMergeExec {
@@ -140,7 +154,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
Ok(Arc::new(SortPreservingMergeExec::new(
self.expr.clone(),
children[0].clone(),
- )))
+ ).with_fetch(self.fetch)))
}
fn execute(
@@ -192,6 +206,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
&self.expr,
BaselineMetrics::new(&self.metrics, partition),
context.session_config().batch_size(),
+ self.fetch,
)?;
debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");
@@ -209,7 +224,12 @@ impl ExecutionPlan for SortPreservingMergeExec {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
let expr: Vec<String> = self.expr.iter().map(|e| e.to_string()).collect();
- write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))
+ write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))?;
+ if let Some(fetch) = self.fetch {
+ write!(f, "fetch={fetch}")?;
+ };
+
+ Ok(())
}
}
}
@@ -814,6 +834,7 @@ mod tests {
sort.as_slice(),
BaselineMetrics::new(&metrics, 0),
task_ctx.session_config().batch_size(),
+ None,
)
.unwrap();
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 4cc80c207c..4f0d324dc1 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1365,6 +1365,8 @@ message SortExecNode {
message SortPreservingMergeExecNode {
PhysicalPlanNode input = 1;
repeated PhysicalExprNode expr = 2;
+ // Maximum number of highest/lowest rows to fetch; negative means no limit
+ int64 fetch = 3;
}
message CoalesceBatchesExecNode {
diff --git a/datafusion/proto/proto/proto_descriptor.bin b/datafusion/proto/proto/proto_descriptor.bin
new file mode 100644
index 0000000000..448c1eda25
Binary files /dev/null and b/datafusion/proto/proto/proto_descriptor.bin differ
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/datafusion.rs
similarity index 99%
copy from datafusion/proto/src/generated/prost.rs
copy to datafusion/proto/src/datafusion.rs
index 31086deead..ae5dfe14b5 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/datafusion.rs
@@ -1923,6 +1923,9 @@ pub struct SortPreservingMergeExecNode {
pub input: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
#[prost(message, repeated, tag = "2")]
pub expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
+ /// Maximum number of highest/lowest rows to fetch; negative means no limit
+ #[prost(int64, tag = "3")]
+ pub fetch: i64,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/datafusion.serde.rs
similarity index 99%
copy from datafusion/proto/src/generated/pbjson.rs
copy to datafusion/proto/src/datafusion.serde.rs
index 42397e3da2..ab8ddf4f29 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/datafusion.serde.rs
@@ -20269,6 +20269,9 @@ impl serde::Serialize for SortPreservingMergeExecNode {
if !self.expr.is_empty() {
len += 1;
}
+ if self.fetch != 0 {
+ len += 1;
+ }
let mut struct_ser = serializer.serialize_struct("datafusion.SortPreservingMergeExecNode", len)?;
if let Some(v) = self.input.as_ref() {
struct_ser.serialize_field("input", v)?;
@@ -20276,6 +20279,9 @@ impl serde::Serialize for SortPreservingMergeExecNode {
if !self.expr.is_empty() {
struct_ser.serialize_field("expr", &self.expr)?;
}
+ if self.fetch != 0 {
+ struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?;
+ }
struct_ser.end()
}
}
@@ -20288,12 +20294,14 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
const FIELDS: &[&str] = &[
"input",
"expr",
+ "fetch",
];
#[allow(clippy::enum_variant_names)]
enum GeneratedField {
Input,
Expr,
+ Fetch,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -20317,6 +20325,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
match value {
"input" => Ok(GeneratedField::Input),
"expr" => Ok(GeneratedField::Expr),
+ "fetch" => Ok(GeneratedField::Fetch),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
@@ -20338,6 +20347,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
{
let mut input__ = None;
let mut expr__ = None;
+ let mut fetch__ = None;
while let Some(k) = map.next_key()? {
match k {
GeneratedField::Input => {
@@ -20352,11 +20362,20 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
}
expr__ = Some(map.next_value()?);
}
+ GeneratedField::Fetch => {
+ if fetch__.is_some() {
+ return Err(serde::de::Error::duplicate_field("fetch"));
+ }
+ fetch__ =
+ Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+ ;
+ }
}
}
Ok(SortPreservingMergeExecNode {
input: input__,
expr: expr__.unwrap_or_default(),
+ fetch: fetch__.unwrap_or_default(),
})
}
}
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 42397e3da2..ab8ddf4f29 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -20269,6 +20269,9 @@ impl serde::Serialize for SortPreservingMergeExecNode {
if !self.expr.is_empty() {
len += 1;
}
+ if self.fetch != 0 {
+ len += 1;
+ }
let mut struct_ser = serializer.serialize_struct("datafusion.SortPreservingMergeExecNode", len)?;
if let Some(v) = self.input.as_ref() {
struct_ser.serialize_field("input", v)?;
@@ -20276,6 +20279,9 @@ impl serde::Serialize for SortPreservingMergeExecNode {
if !self.expr.is_empty() {
struct_ser.serialize_field("expr", &self.expr)?;
}
+ if self.fetch != 0 {
+ struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?;
+ }
struct_ser.end()
}
}
@@ -20288,12 +20294,14 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
const FIELDS: &[&str] = &[
"input",
"expr",
+ "fetch",
];
#[allow(clippy::enum_variant_names)]
enum GeneratedField {
Input,
Expr,
+ Fetch,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -20317,6 +20325,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
match value {
"input" => Ok(GeneratedField::Input),
"expr" => Ok(GeneratedField::Expr),
+ "fetch" => Ok(GeneratedField::Fetch),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
@@ -20338,6 +20347,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
{
let mut input__ = None;
let mut expr__ = None;
+ let mut fetch__ = None;
while let Some(k) = map.next_key()? {
match k {
GeneratedField::Input => {
@@ -20352,11 +20362,20 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode {
}
expr__ = Some(map.next_value()?);
}
+ GeneratedField::Fetch => {
+ if fetch__.is_some() {
+ return Err(serde::de::Error::duplicate_field("fetch"));
+ }
+ fetch__ =
+ Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+ ;
+ }
}
}
Ok(SortPreservingMergeExecNode {
input: input__,
expr: expr__.unwrap_or_default(),
+ fetch: fetch__.unwrap_or_default(),
})
}
}
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 31086deead..ae5dfe14b5 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1923,6 +1923,9 @@ pub struct SortPreservingMergeExecNode {
pub input: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
#[prost(message, repeated, tag = "2")]
pub expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
+ /// Maximum number of highest/lowest rows to fetch; negative means no limit
+ #[prost(int64, tag = "3")]
+ pub fetch: i64,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 1daa1c2e4b..ba1e9e0b09 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -1144,6 +1144,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
Box::new(protobuf::SortPreservingMergeExecNode {
input: Some(Box::new(input)),
expr,
+ fetch: exec.fetch().map(|f|f as i64).unwrap_or(-1)
}),
)),
})