You are viewing a plain text version of this content. (The hyperlink to the canonical version was not preserved in this plain-text copy.)
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/10/05 18:29:03 UTC
[arrow-datafusion] branch master updated: Fix aggregate type coercion bug (#3710)
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 64669e997 Fix aggregate type coercion bug (#3710)
64669e997 is described below
commit 64669e997bce2f90b400614e97a87a60c5a25f3c
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Wed Oct 5 14:28:57 2022 -0400
Fix aggregate type coercion bug (#3710)
* Do not change output expr name in `UnwrapCastInComparison`
* Update
* Update test
* Fix regression
* Update tests
* clippy
---
datafusion/optimizer/src/optimizer.rs | 16 ++++++----
.../optimizer/src/unwrap_cast_in_comparison.rs | 37 +++++++++++++++++++++-
datafusion/optimizer/tests/integration-test.rs | 21 ++++++++++--
3 files changed, 65 insertions(+), 9 deletions(-)
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index 5ef5cfdd5..aa10cd8a7 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -178,16 +178,15 @@ impl Optimizer {
F: FnMut(&LogicalPlan, &dyn OptimizerRule),
{
let mut new_plan = plan.clone();
- debug!("Input logical plan:\n{}\n", plan.display_indent());
- trace!("Full input logical plan:\n{:?}", plan);
+ log_plan("Optimizer input", plan);
+
for rule in &self.rules {
let result = rule.optimize(&new_plan, optimizer_config);
match result {
Ok(plan) => {
new_plan = plan;
observer(&new_plan, rule.as_ref());
- debug!("After apply {} rule:\n", rule.name());
- debug!("Optimized logical plan:\n{}\n", new_plan.display_indent());
+ log_plan(rule.name(), &new_plan);
}
Err(ref e) => {
if optimizer_config.skip_failing_rules {
@@ -209,12 +208,17 @@ impl Optimizer {
}
}
}
- debug!("Optimized logical plan:\n{}\n", new_plan.display_indent());
- trace!("Full Optimized logical plan:\n {:?}", new_plan);
+ log_plan("Optimized plan", &new_plan);
Ok(new_plan)
}
}
+/// Log the plan in debug/tracing mode after some part of the optimizer runs
+fn log_plan(description: &str, plan: &LogicalPlan) {
+ debug!("{description}:\n{}\n", plan.display_indent());
+ trace!("{description}::\n{}\n", plan.display_indent_schema());
+}
+
#[cfg(test)]
mod tests {
use crate::optimizer::Optimizer;
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 7d6858362..542c29bd7 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -97,12 +97,47 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
let new_exprs = plan
.expressions()
.into_iter()
- .map(|expr| expr.rewrite(&mut expr_rewriter))
+ .map(|expr| {
+ let original_name = name_for_alias(&expr)?;
+ let expr = expr.rewrite(&mut expr_rewriter)?;
+ add_alias_if_changed(&original_name, expr)
+ })
.collect::<Result<Vec<_>>>()?;
from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
}
+fn name_for_alias(expr: &Expr) -> Result<String> {
+ match expr {
+ Expr::Sort { expr, .. } => name_for_alias(expr),
+ expr => expr.name(),
+ }
+}
+
+fn add_alias_if_changed(original_name: &str, expr: Expr) -> Result<Expr> {
+ let new_name = name_for_alias(&expr)?;
+
+ if new_name == original_name {
+ return Ok(expr);
+ }
+
+ Ok(match expr {
+ Expr::Sort {
+ expr,
+ asc,
+ nulls_first,
+ } => {
+ let expr = add_alias_if_changed(original_name, *expr)?;
+ Expr::Sort {
+ expr: Box::new(expr),
+ asc,
+ nulls_first,
+ }
+ }
+ expr => expr.alias(original_name),
+ })
+}
+
struct UnwrapCastExprRewriter {
schema: DFSchemaRef,
}
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index 2d9546f13..dc452af3b 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -29,13 +29,19 @@ use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
+#[cfg(test)]
+#[ctor::ctor]
+fn init() {
+ let _ = env_logger::try_init();
+}
+
#[test]
fn case_when() -> Result<()> {
let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test";
let plan = test_sql(sql)?;
let expected =
- "Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END\
- \n TableScan: test projection=[col_int32]";
+ "Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END\
+ \n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{:?}", plan));
let sql = "SELECT CASE WHEN col_uint32 > 0 THEN 1 ELSE 0 END FROM test";
@@ -46,6 +52,17 @@ fn case_when() -> Result<()> {
Ok(())
}
+#[test]
+fn case_when_aggregate() -> Result<()> {
+ let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8";
+ let plan = test_sql(sql)?;
+ let expected = "Projection: test.col_utf8, SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\
+ \n Aggregate: groupBy=[[test.col_utf8]], aggr=[[SUM(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\
+ \n TableScan: test projection=[col_int32, col_utf8]";
+ assert_eq!(expected, format!("{:?}", plan));
+ Ok(())
+}
+
#[test]
fn unsigned_target_type() -> Result<()> {
let sql = "SELECT * FROM test WHERE col_uint32 > 0";