You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/04/28 17:49:39 UTC

[arrow-datafusion] branch master updated: rewrite approx_median to approx_percentile_cont while planning phase (#2262)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 7b61d5233 rewrite approx_median to approx_percentile_cont while planning phase (#2262)
7b61d5233 is described below

commit 7b61d5233fbe08e458f74bde3a958125db6abe00
Author: Eduard Karacharov <13...@users.noreply.github.com>
AuthorDate: Thu Apr 28 20:49:33 2022 +0300

    rewrite approx_median to approx_percentile_cont while planning phase (#2262)
---
 datafusion/core/src/execution/context.rs        |   5 -
 datafusion/core/src/optimizer/mod.rs            |   1 -
 datafusion/core/src/optimizer/to_approx_perc.rs | 161 ------------------------
 datafusion/core/src/sql/planner.rs              |  64 +++++++---
 4 files changed, 50 insertions(+), 181 deletions(-)

diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 5378e38da..33019ff4a 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -72,7 +72,6 @@ use crate::optimizer::optimizer::OptimizerRule;
 use crate::optimizer::projection_push_down::ProjectionPushDown;
 use crate::optimizer::simplify_expressions::SimplifyExpressions;
 use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
-use crate::optimizer::to_approx_perc::ToApproxPerc;
 
 use crate::physical_optimizer::coalesce_batches::CoalesceBatches;
 use crate::physical_optimizer::merge_exec::AddCoalescePartitionsExec;
@@ -1210,10 +1209,6 @@ impl SessionState {
                 Arc::new(FilterPushDown::new()),
                 Arc::new(LimitPushDown::new()),
                 Arc::new(SingleDistinctToGroupBy::new()),
-                // ToApproxPerc must be applied last because
-                // it rewrites only the function and may interfere with
-                // other rules
-                Arc::new(ToApproxPerc::new()),
             ],
             physical_optimizers: vec![
                 Arc::new(AggregateStatistics::new()),
diff --git a/datafusion/core/src/optimizer/mod.rs b/datafusion/core/src/optimizer/mod.rs
index cddedfc8a..9f12ecea8 100644
--- a/datafusion/core/src/optimizer/mod.rs
+++ b/datafusion/core/src/optimizer/mod.rs
@@ -28,5 +28,4 @@ pub mod optimizer;
 pub mod projection_push_down;
 pub mod simplify_expressions;
 pub mod single_distinct_to_groupby;
-pub mod to_approx_perc;
 pub mod utils;
diff --git a/datafusion/core/src/optimizer/to_approx_perc.rs b/datafusion/core/src/optimizer/to_approx_perc.rs
deleted file mode 100644
index c33c3f676..000000000
--- a/datafusion/core/src/optimizer/to_approx_perc.rs
+++ /dev/null
@@ -1,161 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! espression/function to approx_percentile optimizer rule
-
-use crate::error::Result;
-use crate::execution::context::ExecutionProps;
-use crate::logical_plan::plan::Aggregate;
-use crate::logical_plan::{Expr, LogicalPlan};
-use crate::optimizer::optimizer::OptimizerRule;
-use crate::optimizer::utils;
-use crate::physical_plan::aggregates;
-use crate::scalar::ScalarValue;
-
-/// espression/function to approx_percentile optimizer rule
-///  ```text
-///    SELECT F1(s)
-///    ...
-///
-///    Into
-///
-///    SELECT APPROX_PERCENTILE_CONT(s, lit(n)) as "F1(s)"
-///    ...
-///  ```
-pub struct ToApproxPerc {}
-
-impl ToApproxPerc {
-    #[allow(missing_docs)]
-    pub fn new() -> Self {
-        Self {}
-    }
-}
-
-impl Default for ToApproxPerc {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
-    match plan {
-        LogicalPlan::Aggregate(Aggregate {
-            input,
-            aggr_expr,
-            schema,
-            group_expr,
-        }) => {
-            let new_aggr_expr = aggr_expr
-                .iter()
-                .map(|agg_expr| replace_with_percentile(agg_expr).unwrap())
-                .collect::<Vec<_>>();
-
-            Ok(LogicalPlan::Aggregate(Aggregate {
-                input: input.clone(),
-                aggr_expr: new_aggr_expr,
-                schema: schema.clone(),
-                group_expr: group_expr.clone(),
-            }))
-        }
-        _ => optimize_children(plan),
-    }
-}
-
-fn optimize_children(plan: &LogicalPlan) -> Result<LogicalPlan> {
-    let expr = plan.expressions();
-    let inputs = plan.inputs();
-    let new_inputs = inputs
-        .iter()
-        .map(|plan| optimize(plan))
-        .collect::<Result<Vec<_>>>()?;
-    utils::from_plan(plan, &expr, &new_inputs)
-}
-
-fn replace_with_percentile(expr: &Expr) -> Result<Expr> {
-    match expr {
-        Expr::AggregateFunction {
-            fun,
-            args,
-            distinct,
-        } => {
-            let mut new_args = args.clone();
-            let mut new_func = fun.clone();
-            if fun == &aggregates::AggregateFunction::ApproxMedian {
-                new_args.push(Expr::Literal(ScalarValue::Float64(Some(0.5_f64))));
-                new_func = aggregates::AggregateFunction::ApproxPercentileCont;
-            }
-
-            Ok(Expr::AggregateFunction {
-                fun: new_func,
-                args: new_args,
-                distinct: *distinct,
-            })
-        }
-        _ => Ok(expr.clone()),
-    }
-}
-
-impl OptimizerRule for ToApproxPerc {
-    fn optimize(
-        &self,
-        plan: &LogicalPlan,
-        _execution_props: &ExecutionProps,
-    ) -> Result<LogicalPlan> {
-        optimize(plan)
-    }
-    fn name(&self) -> &str {
-        "ToApproxPerc"
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::logical_plan::{col, LogicalPlanBuilder};
-    use crate::physical_plan::aggregates;
-    use crate::test::*;
-
-    fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
-        let rule = ToApproxPerc::new();
-        let optimized_plan = rule
-            .optimize(plan, &ExecutionProps::new())
-            .expect("failed to optimize plan");
-        let formatted_plan = format!("{}", optimized_plan.display_indent_schema());
-        assert_eq!(formatted_plan, expected);
-    }
-
-    #[test]
-    fn median_1() -> Result<()> {
-        let table_scan = test_table_scan()?;
-        let expr = Expr::AggregateFunction {
-            fun: aggregates::AggregateFunction::ApproxMedian,
-            distinct: false,
-            args: vec![col("b")],
-        };
-
-        let plan = LogicalPlanBuilder::from(table_scan)
-            .aggregate(Vec::<Expr>::new(), vec![expr])?
-            .build()?;
-
-        // Rewrite to use approx_percentile
-        let expected = "Aggregate: groupBy=[[]], aggr=[[APPROXPERCENTILECONT(#test.b, Float64(0.5))]] [APPROXMEDIAN(test.b):UInt32;N]\
-                            \n  TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
-
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
-    }
-}
diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs
index b45c66704..855606542 100644
--- a/datafusion/core/src/sql/planner.rs
+++ b/datafusion/core/src/sql/planner.rs
@@ -1817,15 +1817,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                         WindowFunction::AggregateFunction(
                             aggregate_fun,
                         ) => {
+                            let (aggregate_fun, args) = self.aggregate_fn_to_expr(
+                                aggregate_fun,
+                                function,
+                                schema,
+                            )?;
+
                             return Ok(Expr::WindowFunction {
                                 fun: WindowFunction::AggregateFunction(
-                                    aggregate_fun.clone(),
-                                ),
-                                args: self.aggregate_fn_to_expr(
                                     aggregate_fun,
-                                    function,
-                                    schema,
-                                )?,
+                                ),
+                                args,
                                 partition_by,
                                 order_by,
                                 window_frame,
@@ -1850,7 +1852,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 // next, aggregate built-ins
                 if let Ok(fun) = aggregates::AggregateFunction::from_str(&name) {
                     let distinct = function.distinct;
-                    let args = self.aggregate_fn_to_expr(fun.clone(), function, schema)?;
+                    let (fun, args) = self.aggregate_fn_to_expr(fun, function, schema)?;
                     return Ok(Expr::AggregateFunction {
                         fun,
                         distinct,
@@ -1952,9 +1954,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         fun: aggregates::AggregateFunction,
         function: sqlparser::ast::Function,
         schema: &DFSchema,
-    ) -> Result<Vec<Expr>> {
-        if fun == aggregates::AggregateFunction::Count {
-            function
+    ) -> Result<(aggregates::AggregateFunction, Vec<Expr>)> {
+        let args = match fun {
+            aggregates::AggregateFunction::Count => function
                 .args
                 .into_iter()
                 .map(|a| match a {
@@ -1964,10 +1966,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(lit(1_u8)),
                     _ => self.sql_fn_arg_to_logical_expr(a, schema),
                 })
-                .collect::<Result<Vec<Expr>>>()
-        } else {
-            self.function_args_to_expr(function.args, schema)
-        }
+                .collect::<Result<Vec<Expr>>>()?,
+            aggregates::AggregateFunction::ApproxMedian => function
+                .args
+                .into_iter()
+                .map(|a| self.sql_fn_arg_to_logical_expr(a, schema))
+                .chain(iter::once(Ok(lit(0.5_f64))))
+                .collect::<Result<Vec<Expr>>>()?,
+            _ => self.function_args_to_expr(function.args, schema)?,
+        };
+
+        let fun = match fun {
+            aggregates::AggregateFunction::ApproxMedian => {
+                aggregates::AggregateFunction::ApproxPercentileCont
+            }
+            _ => fun,
+        };
+
+        Ok((fun, args))
     }
 
     fn sql_interval_to_literal(
@@ -3349,6 +3365,15 @@ mod tests {
         quick_test(sql, expected);
     }
 
+    #[test]
+    fn select_approx_median() {
+        let sql = "SELECT approx_median(age) FROM person";
+        let expected = "Projection: #APPROXPERCENTILECONT(person.age,Float64(0.5))\
+                        \n  Aggregate: groupBy=[[]], aggr=[[APPROXPERCENTILECONT(#person.age, Float64(0.5))]]\
+                        \n    TableScan: person projection=None";
+        quick_test(sql, expected);
+    }
+
     #[test]
     fn select_scalar_func() {
         let sql = "SELECT sqrt(age) FROM person";
@@ -4105,6 +4130,17 @@ mod tests {
         quick_test(sql, expected);
     }
 
+    #[test]
+    fn approx_median_window() {
+        let sql =
+            "SELECT order_id, APPROX_MEDIAN(qty) OVER(PARTITION BY order_id) from orders";
+        let expected = "\
+        Projection: #orders.order_id, #APPROXPERCENTILECONT(orders.qty,Float64(0.5)) PARTITION BY [#orders.order_id]\
+        \n  WindowAggr: windowExpr=[[APPROXPERCENTILECONT(#orders.qty, Float64(0.5)) PARTITION BY [#orders.order_id]]]\
+        \n    TableScan: orders projection=None";
+        quick_test(sql, expected);
+    }
+
     #[test]
     fn select_typedstring() {
         let sql = "SELECT date '2020-12-10' AS date FROM person";