You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/10/05 18:29:03 UTC

[arrow-datafusion] branch master updated: Fix aggregate type coercion bug (#3710)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 64669e997 Fix aggregate type coercion bug (#3710)
64669e997 is described below

commit 64669e997bce2f90b400614e97a87a60c5a25f3c
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Wed Oct 5 14:28:57 2022 -0400

    Fix aggregate type coercion bug (#3710)
    
    * Do not change output expr name in `UnwrapCastInComparison`
    
    * Update
    
    * Update test
    
    * Fix regression
    
    * Update tests
    
    * clippy
---
 datafusion/optimizer/src/optimizer.rs              | 16 ++++++----
 .../optimizer/src/unwrap_cast_in_comparison.rs     | 37 +++++++++++++++++++++-
 datafusion/optimizer/tests/integration-test.rs     | 21 ++++++++++--
 3 files changed, 65 insertions(+), 9 deletions(-)

diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index 5ef5cfdd5..aa10cd8a7 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -178,16 +178,15 @@ impl Optimizer {
         F: FnMut(&LogicalPlan, &dyn OptimizerRule),
     {
         let mut new_plan = plan.clone();
-        debug!("Input logical plan:\n{}\n", plan.display_indent());
-        trace!("Full input logical plan:\n{:?}", plan);
+        log_plan("Optimizer input", plan);
+
         for rule in &self.rules {
             let result = rule.optimize(&new_plan, optimizer_config);
             match result {
                 Ok(plan) => {
                     new_plan = plan;
                     observer(&new_plan, rule.as_ref());
-                    debug!("After apply {} rule:\n", rule.name());
-                    debug!("Optimized logical plan:\n{}\n", new_plan.display_indent());
+                    log_plan(rule.name(), &new_plan);
                 }
                 Err(ref e) => {
                     if optimizer_config.skip_failing_rules {
@@ -209,12 +208,17 @@ impl Optimizer {
                 }
             }
         }
-        debug!("Optimized logical plan:\n{}\n", new_plan.display_indent());
-        trace!("Full Optimized logical plan:\n {:?}", new_plan);
+        log_plan("Optimized plan", &new_plan);
         Ok(new_plan)
     }
 }
 
+/// Log `plan` labelled `description`: `display_indent` at debug, `display_indent_schema` at trace
+fn log_plan(description: &str, plan: &LogicalPlan) {
+    debug!("{description}:\n{}\n", plan.display_indent());
+    trace!("{description}:\n{}\n", plan.display_indent_schema());
+}
+
 #[cfg(test)]
 mod tests {
     use crate::optimizer::Optimizer;
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 7d6858362..542c29bd7 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -97,12 +97,47 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
     let new_exprs = plan
         .expressions()
         .into_iter()
-        .map(|expr| expr.rewrite(&mut expr_rewriter))
+        .map(|expr| {
+            let original_name = name_for_alias(&expr)?;
+            let expr = expr.rewrite(&mut expr_rewriter)?;
+            add_alias_if_changed(&original_name, expr)
+        })
         .collect::<Result<Vec<_>>>()?;
 
     from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
 }
 
+/// Name of `expr` per `Expr::name`, looking through any `Sort` wrapper to the expression it wraps.
+fn name_for_alias(expr: &Expr) -> Result<String> {
+    if let Expr::Sort { expr: inner, .. } = expr {
+        return name_for_alias(inner);
+    }
+    expr.name()
+}
+
+/// If rewriting changed the name of `expr`, alias it back to `original_name` so the rewrite
+/// leaves the expression's name unchanged. A `Sort` keeps its `asc`/`nulls_first` options
+/// and the alias is placed on the expression it wraps, not on the `Sort` itself.
+fn add_alias_if_changed(original_name: &str, expr: Expr) -> Result<Expr> {
+    if name_for_alias(&expr)? == original_name {
+        return Ok(expr);
+    }
+
+    let aliased = match expr {
+        Expr::Sort {
+            expr: inner,
+            asc,
+            nulls_first,
+        } => Expr::Sort {
+            expr: Box::new(add_alias_if_changed(original_name, *inner)?),
+            asc,
+            nulls_first,
+        },
+        other => other.alias(original_name),
+    };
+    Ok(aliased)
+}
+
 struct UnwrapCastExprRewriter {
     schema: DFSchemaRef,
 }
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index 2d9546f13..dc452af3b 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -29,13 +29,19 @@ use std::any::Any;
 use std::collections::HashMap;
 use std::sync::Arc;
 
+/// Install `env_logger` once when the test binary loads, so optimizer plan logging can be seen.
+#[ctor::ctor]
+fn init() {
+    let _ = env_logger::try_init(); // an `Err` only means a global logger is already installed
+}
+
 #[test]
 fn case_when() -> Result<()> {
     let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test";
     let plan = test_sql(sql)?;
     let expected =
-        "Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END\
-    \n  TableScan: test projection=[col_int32]";
+        "Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END\
+         \n  TableScan: test projection=[col_int32]";
     assert_eq!(expected, format!("{:?}", plan));
 
     let sql = "SELECT CASE WHEN col_uint32 > 0 THEN 1 ELSE 0 END FROM test";
@@ -46,6 +52,17 @@ fn case_when() -> Result<()> {
     Ok(())
 }
 
+/// The aggregate rewritten by `UnwrapCastInComparison` keeps its original name for the projection.
+#[test]
+fn case_when_aggregate() -> Result<()> {
+    let plan = test_sql("SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8")?;
+    let expected = "Projection: test.col_utf8, SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\
+                    \n  Aggregate: groupBy=[[test.col_utf8]], aggr=[[SUM(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\
+                    \n    TableScan: test projection=[col_int32, col_utf8]";
+    assert_eq!(expected, format!("{:?}", plan));
+    Ok(())
+}
+
 #[test]
 fn unsigned_target_type() -> Result<()> {
     let sql = "SELECT * FROM test WHERE col_uint32 > 0";