You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/26 01:04:25 UTC

[arrow-datafusion] branch master updated: Simplify the `PushDownLimit`. (#5021)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new a4ebd395c Simplify the `PushDownLimit`. (#5021)
a4ebd395c is described below

commit a4ebd395cb30588aab9cf49b43da82b04d2ac70d
Author: Remzi Yang <59...@users.noreply.github.com>
AuthorDate: Thu Jan 26 09:04:20 2023 +0800

    Simplify the `PushDownLimit`. (#5021)
    
    * rewrite_merge_limit
    
    Signed-off-by: remzi <13...@gmail.com>
    
    * rewrite join
    
    Signed-off-by: remzi <13...@gmail.com>
    
    * address comments
    
    Signed-off-by: remzi <13...@gmail.com>
    
    Signed-off-by: remzi <13...@gmail.com>
---
 datafusion/optimizer/src/push_down_limit.rs | 249 +++++++++++++++-------------
 1 file changed, 138 insertions(+), 111 deletions(-)

diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs
index 261909183..9fbc61fc5 100644
--- a/datafusion/optimizer/src/push_down_limit.rs
+++ b/datafusion/optimizer/src/push_down_limit.rs
@@ -37,43 +37,6 @@ impl PushDownLimit {
     }
 }
 
-fn is_no_join_condition(join: &Join) -> bool {
-    join.on.is_empty() && join.filter.is_none()
-}
-
-fn push_down_join(
-    join: &Join,
-    left_limit: Option<usize>,
-    right_limit: Option<usize>,
-) -> LogicalPlan {
-    let left = match left_limit {
-        Some(limit) => LogicalPlan::Limit(Limit {
-            skip: 0,
-            fetch: Some(limit),
-            input: Arc::new((*join.left).clone()),
-        }),
-        None => (*join.left).clone(),
-    };
-    let right = match right_limit {
-        Some(limit) => LogicalPlan::Limit(Limit {
-            skip: 0,
-            fetch: Some(limit),
-            input: Arc::new((*join.right).clone()),
-        }),
-        None => (*join.right).clone(),
-    };
-    LogicalPlan::Join(Join {
-        left: Arc::new(left),
-        right: Arc::new(right),
-        on: join.on.clone(),
-        filter: join.filter.clone(),
-        join_type: join.join_type,
-        join_constraint: join.join_constraint,
-        schema: join.schema.clone(),
-        null_equals_null: join.null_equals_null,
-    })
-}
-
 /// Push down Limit.
 impl OptimizerRule for PushDownLimit {
     fn try_optimize(
@@ -81,39 +44,64 @@ impl OptimizerRule for PushDownLimit {
         plan: &LogicalPlan,
         _config: &dyn OptimizerConfig,
     ) -> Result<Option<LogicalPlan>> {
+        use std::cmp::min;
+
         let limit = match plan {
             LogicalPlan::Limit(limit) => limit,
             _ => return Ok(None),
         };
 
-        if let LogicalPlan::Limit(child_limit) = &*limit.input {
+        if let LogicalPlan::Limit(child) = &*limit.input {
+            // Merge the Parent Limit and the Child Limit.
+
+            //  Case 0: Parent and Child are disjoint. (child_fetch <= skip)
+            //   Before merging:
+            //                     |........skip........|---fetch-->|              Parent Limit
+            //    |...child_skip...|---child_fetch-->|                             Child Limit
+            //   After merging:
+            //    |.........(child_skip + skip).........|
+            //   Before merging:
+            //                     |...skip...|------------fetch------------>|     Parent Limit
+            //    |...child_skip...|-------------child_fetch------------>|         Child Limit
+            //   After merging:
+            //    |....(child_skip + skip)....|---(child_fetch - skip)-->|
+
+            //  Case 1: Parent is beyond the range of Child. (skip < child_fetch <= skip + fetch)
+            //   Before merging:
+            //                     |...skip...|------------fetch------------>|     Parent Limit
+            //    |...child_skip...|-------------child_fetch------------>|         Child Limit
+            //   After merging:
+            //    |....(child_skip + skip)....|---(child_fetch - skip)-->|
+
+            //  Case 2: Parent is in the range of Child. (skip + fetch < child_fetch)
+            //   Before merging:
+            //                     |...skip...|---fetch-->|                        Parent Limit
+            //    |...child_skip...|-------------child_fetch------------>|         Child Limit
+            //   After merging:
+            //    |....(child_skip + skip)....|---fetch-->|
             let parent_skip = limit.skip;
-            let parent_fetch = limit.fetch;
-
-            // Merge limit
-            // Parent range [child_skip + skip, child_skip + skip + fetch)
-            // Child range [child_skip, child_skip + child_fetch)
-            // Merge -> [child_skip + skip, min(child_skip + skip + fetch, child_skip + child_fetch) )
-            // Merge LimitPlan -> [child_skip + skip, min(fetch, child_fetch - skip) )
-            let new_fetch = match parent_fetch {
-                Some(fetch) => match child_limit.fetch {
-                    Some(child_fetch) => Some(std::cmp::min(
-                        fetch,
-                        fetch_minus_skip(child_fetch, parent_skip),
-                    )),
-                    None => Some(fetch),
-                },
-                _ => child_limit
-                    .fetch
-                    .map(|child_fetch| fetch_minus_skip(child_fetch, parent_skip)),
+            let new_fetch = match (limit.fetch, child.fetch) {
+                (Some(fetch), Some(child_fetch)) => {
+                    Some(min(fetch, child_fetch.saturating_sub(parent_skip)))
+                }
+                (Some(fetch), None) => Some(fetch),
+                (None, Some(child_fetch)) => {
+                    Some(child_fetch.saturating_sub(parent_skip))
+                }
+                (None, None) => None,
             };
 
             let plan = LogicalPlan::Limit(Limit {
-                skip: child_limit.skip + limit.skip,
+                skip: child.skip + parent_skip,
                 fetch: new_fetch,
-                input: Arc::new((*child_limit.input).clone()),
+                input: Arc::new((*child.input).clone()),
             });
-            return self.try_optimize(&plan, _config);
+            return {
+                match self.try_optimize(&plan, _config)? {
+                    Some(new_plan) => Ok(Some(new_plan)),
+                    None => Ok(Some(plan)),
+                }
+            };
         }
 
         let fetch = match limit.fetch {
@@ -121,20 +109,25 @@ impl OptimizerRule for PushDownLimit {
             None => return Ok(None),
         };
         let skip = limit.skip;
-
         let child_plan = &*limit.input;
+
         let plan = match child_plan {
             LogicalPlan::TableScan(scan) => {
                 let limit = if fetch != 0 { fetch + skip } else { 0 };
-                let new_input = LogicalPlan::TableScan(TableScan {
-                    table_name: scan.table_name.clone(),
-                    source: scan.source.clone(),
-                    projection: scan.projection.clone(),
-                    filters: scan.filters.clone(),
-                    fetch: scan.fetch.map(|x| std::cmp::min(x, limit)).or(Some(limit)),
-                    projected_schema: scan.projected_schema.clone(),
-                });
-                plan.with_new_inputs(&[new_input])?
+                let new_fetch = scan.fetch.map(|x| min(x, limit)).or(Some(limit));
+                if new_fetch == scan.fetch {
+                    None
+                } else {
+                    let new_input = LogicalPlan::TableScan(TableScan {
+                        table_name: scan.table_name.clone(),
+                        source: scan.source.clone(),
+                        projection: scan.projection.clone(),
+                        filters: scan.filters.clone(),
+                        fetch: scan.fetch.map(|x| min(x, limit)).or(Some(limit)),
+                        projected_schema: scan.projected_schema.clone(),
+                    });
+                    Some(plan.with_new_inputs(&[new_input])?)
+                }
             }
             LogicalPlan::Union(union) => {
                 let new_inputs = union
@@ -152,7 +145,7 @@ impl OptimizerRule for PushDownLimit {
                     inputs: new_inputs,
                     schema: union.schema.clone(),
                 });
-                plan.with_new_inputs(&[union])?
+                Some(plan.with_new_inputs(&[union])?)
             }
 
             LogicalPlan::CrossJoin(cross_join) => {
@@ -173,47 +166,34 @@ impl OptimizerRule for PushDownLimit {
                     right: Arc::new(new_right),
                     schema: plan.schema().clone(),
                 });
-                plan.with_new_inputs(&[new_cross_join])?
+                Some(plan.with_new_inputs(&[new_cross_join])?)
             }
 
             LogicalPlan::Join(join) => {
-                let limit = fetch + skip;
-                let new_join = match join.join_type {
-                    JoinType::Left | JoinType::Right | JoinType::Full
-                        if is_no_join_condition(join) =>
-                    {
-                        // push left and right
-                        push_down_join(join, Some(limit), Some(limit))
+                let new_join = push_down_join(join, fetch + skip);
+                match new_join {
+                    Some(new_join) => {
+                        Some(plan.with_new_inputs(&[LogicalPlan::Join(new_join)])?)
                     }
-                    JoinType::LeftSemi | JoinType::LeftAnti
-                        if is_no_join_condition(join) =>
-                    {
-                        // push left
-                        push_down_join(join, Some(limit), None)
-                    }
-                    JoinType::RightSemi | JoinType::RightAnti
-                        if is_no_join_condition(join) =>
-                    {
-                        // push right
-                        push_down_join(join, None, Some(limit))
-                    }
-                    JoinType::Left => push_down_join(join, Some(limit), None),
-                    JoinType::Right => push_down_join(join, None, Some(limit)),
-                    _ => push_down_join(join, None, None),
-                };
-                plan.with_new_inputs(&[new_join])?
+                    None => None,
+                }
             }
 
             LogicalPlan::Sort(sort) => {
-                let sort_fetch = skip + fetch;
-                let new_sort = LogicalPlan::Sort(Sort {
-                    expr: sort.expr.clone(),
-                    input: Arc::new((*sort.input).clone()),
-                    fetch: Some(
-                        sort.fetch.map(|f| f.min(sort_fetch)).unwrap_or(sort_fetch),
-                    ),
-                });
-                plan.with_new_inputs(&[new_sort])?
+                let new_fetch = {
+                    let sort_fetch = skip + fetch;
+                    Some(sort.fetch.map(|f| f.min(sort_fetch)).unwrap_or(sort_fetch))
+                };
+                if new_fetch == sort.fetch {
+                    None
+                } else {
+                    let new_sort = LogicalPlan::Sort(Sort {
+                        expr: sort.expr.clone(),
+                        input: Arc::new((*sort.input).clone()),
+                        fetch: new_fetch,
+                    });
+                    Some(plan.with_new_inputs(&[new_sort])?)
+                }
             }
             LogicalPlan::Projection(_) | LogicalPlan::SubqueryAlias(_) => {
                 // commute
@@ -221,12 +201,12 @@ impl OptimizerRule for PushDownLimit {
                     plan.with_new_inputs(&[
                         (*(child_plan.inputs().get(0).unwrap())).clone()
                     ])?;
-                child_plan.with_new_inputs(&[new_limit])?
+                Some(child_plan.with_new_inputs(&[new_limit])?)
             }
-            _ => plan.clone(),
+            _ => None,
         };
 
-        Ok(Some(plan))
+        Ok(plan)
     }
 
     fn name(&self) -> &str {
@@ -238,11 +218,58 @@ impl OptimizerRule for PushDownLimit {
     }
 }
 
-fn fetch_minus_skip(fetch: usize, skip: usize) -> usize {
-    if skip > fetch {
-        0
+fn push_down_join(join: &Join, limit: usize) -> Option<Join> {
+    use JoinType::*;
+
+    fn is_no_join_condition(join: &Join) -> bool {
+        join.on.is_empty() && join.filter.is_none()
+    }
+
+    let (left_limit, right_limit) = if is_no_join_condition(join) {
+        match join.join_type {
+            Left | Right | Full => (Some(limit), Some(limit)),
+            LeftAnti | LeftSemi => (Some(limit), None),
+            RightAnti | RightSemi => (None, Some(limit)),
+            Inner => (None, None),
+        }
     } else {
-        fetch - skip
+        match join.join_type {
+            Left => (Some(limit), None),
+            Right => (None, Some(limit)),
+            _ => (None, None),
+        }
+    };
+
+    match (left_limit, right_limit) {
+        (None, None) => None,
+        _ => {
+            let left = match left_limit {
+                Some(limit) => LogicalPlan::Limit(Limit {
+                    skip: 0,
+                    fetch: Some(limit),
+                    input: Arc::new((*join.left).clone()),
+                }),
+                None => (*join.left).clone(),
+            };
+            let right = match right_limit {
+                Some(limit) => LogicalPlan::Limit(Limit {
+                    skip: 0,
+                    fetch: Some(limit),
+                    input: Arc::new((*join.right).clone()),
+                }),
+                None => (*join.right).clone(),
+            };
+            Some(Join {
+                left: Arc::new(left),
+                right: Arc::new(right),
+                on: join.on.clone(),
+                filter: join.filter.clone(),
+                join_type: join.join_type,
+                join_constraint: join.join_constraint,
+                schema: join.schema.clone(),
+                null_equals_null: join.null_equals_null,
+            })
+        }
     }
 }