You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by xu...@apache.org on 2022/10/12 15:18:25 UTC
[arrow-datafusion] branch master updated: Consolidate and better tests for expression re-rewriting / aliasing (#3727)
This is an automated email from the ASF dual-hosted git repository.
xudong963 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new e27e86baa Consolidate and better tests for expression re-rewriting / aliasing (#3727)
e27e86baa is described below
commit e27e86baa5a54dd4643055ee3b1521865d0977cd
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Wed Oct 12 11:18:17 2022 -0400
Consolidate and better tests for expression re-rewriting / aliasing (#3727)
---
datafusion/expr/src/expr.rs | 4 +-
datafusion/optimizer/src/type_coercion.rs | 33 ++----
.../optimizer/src/unwrap_cast_in_comparison.rs | 40 +------
datafusion/optimizer/src/utils.rs | 122 ++++++++++++++++++++-
4 files changed, 136 insertions(+), 63 deletions(-)
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index c131682a8..16c3da078 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -482,8 +482,8 @@ impl Expr {
}
/// Return `self AS name` alias expression
- pub fn alias(self, name: &str) -> Expr {
- Expr::Alias(Box::new(self), name.to_owned())
+ pub fn alias(self, name: impl Into<String>) -> Expr {
+ Expr::Alias(Box::new(self), name.into())
}
/// Return `self IN <list>` if `negated` is false, otherwise
diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs
index be4d70265..ad9314406 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -17,11 +17,12 @@
//! Optimizer rule for type validation and coercion
+use crate::utils::rewrite_preserving_name;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
use datafusion_expr::expr::Case;
-use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
+use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion};
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::binary::{coerce_types, comparison_coercion};
use datafusion_expr::type_coercion::functions::data_types;
@@ -91,30 +92,13 @@ fn optimize_internal(
schema: Arc::new(schema),
};
- let original_expr_names: Vec<Option<String>> = plan
- .expressions()
- .iter()
- .map(|expr| expr.name().ok())
- .collect();
-
let new_expr = plan
.expressions()
.into_iter()
- .zip(original_expr_names)
- .map(|(expr, original_name)| {
- let expr = expr.rewrite(&mut expr_rewrite)?;
-
+ .map(|expr| {
// ensure aggregate names don't change:
// https://github.com/apache/arrow-datafusion/issues/3555
- if matches!(expr, Expr::AggregateFunction { .. }) {
- if let Some((alias, name)) = original_name.zip(expr.name().ok()) {
- if alias != name {
- return Ok(expr.alias(&alias));
- }
- }
- }
-
- Ok(expr)
+ rewrite_preserving_name(expr, &mut expr_rewrite)
})
.collect::<Result<Vec<_>>>()?;
@@ -815,7 +799,8 @@ mod test {
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
- "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])\n EmptyRelation",
+ "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\
+ \n EmptyRelation",
&format!("{:?}", plan)
);
// a in (1,4,8), a is decimal
@@ -833,7 +818,8 @@ mod test {
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
- "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])\n EmptyRelation",
+ "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\
+ \n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
@@ -931,7 +917,8 @@ mod test {
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config).unwrap();
assert_eq!(
- "Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation",
+ "Projection: a LIKE CAST(NULL AS Utf8) AS a LIKE NULL \
+ \n EmptyRelation",
&format!("{:?}", plan)
);
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 8392e28a9..8d04357ad 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -18,12 +18,13 @@
//! The unwrap-cast comparison rule can currently be applied to binary/inlist comparison exprs; other types
//! of expr can be added if needed.
//! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr.
+use crate::utils::rewrite_preserving_name;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::{
DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue};
-use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
+use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
@@ -97,47 +98,12 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
let new_exprs = plan
.expressions()
.into_iter()
- .map(|expr| {
- let original_name = name_for_alias(&expr)?;
- let expr = expr.rewrite(&mut expr_rewriter)?;
- add_alias_if_changed(&original_name, expr)
- })
+ .map(|expr| rewrite_preserving_name(expr, &mut expr_rewriter))
.collect::<Result<Vec<_>>>()?;
from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
}
-fn name_for_alias(expr: &Expr) -> Result<String> {
- match expr {
- Expr::Sort { expr, .. } => name_for_alias(expr),
- expr => expr.name(),
- }
-}
-
-fn add_alias_if_changed(original_name: &str, expr: Expr) -> Result<Expr> {
- let new_name = name_for_alias(&expr)?;
-
- if new_name == original_name {
- return Ok(expr);
- }
-
- Ok(match expr {
- Expr::Sort {
- expr,
- asc,
- nulls_first,
- } => {
- let expr = add_alias_if_changed(original_name, *expr)?;
- Expr::Sort {
- expr: Box::new(expr),
- asc,
- nulls_first,
- }
- }
- expr => expr.alias(original_name),
- })
-}
-
struct UnwrapCastExprRewriter {
schema: DFSchemaRef,
}
diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs
index d962dd7b4..a1174276d 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -20,6 +20,7 @@
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
use datafusion_common::{plan_err, Column, DFSchemaRef};
+use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
use datafusion_expr::{
and, col, combine_filters,
@@ -315,13 +316,63 @@ pub fn alias_cols(cols: &[Column]) -> Vec<Expr> {
.collect()
}
+/// Rewrites `expr` using `rewriter`, ensuring that the output has the
+/// same name as `expr` prior to rewrite, adding an alias if necessary.
+///
+ /// This is important when optimizing plans to ensure the output
+ /// schema of plan nodes doesn't change after optimization
+pub fn rewrite_preserving_name<R>(expr: Expr, rewriter: &mut R) -> Result<Expr>
+where
+ R: ExprRewriter<Expr>,
+{
+ let original_name = name_for_alias(&expr)?;
+ let expr = expr.rewrite(rewriter)?;
+ add_alias_if_changed(original_name, expr)
+}
+
+/// Return the name to use for the specific Expr, recursing into
+/// `Expr::Sort` as appropriate
+fn name_for_alias(expr: &Expr) -> Result<String> {
+ match expr {
+ Expr::Sort { expr, .. } => name_for_alias(expr),
+ expr => expr.name(),
+ }
+}
+
+ /// Ensure `expr` has the same name as `original_name` by adding an
+ /// alias if necessary.
+fn add_alias_if_changed(original_name: String, expr: Expr) -> Result<Expr> {
+ let new_name = name_for_alias(&expr)?;
+
+ if new_name == original_name {
+ return Ok(expr);
+ }
+
+ Ok(match expr {
+ Expr::Sort {
+ expr,
+ asc,
+ nulls_first,
+ } => {
+ let expr = add_alias_if_changed(original_name, *expr)?;
+ Expr::Sort {
+ expr: Box::new(expr),
+ asc,
+ nulls_first,
+ }
+ }
+ expr => expr.alias(original_name),
+ })
+}
+
#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::DataType;
use datafusion_common::Column;
- use datafusion_expr::{col, utils::expr_to_columns};
+ use datafusion_expr::{col, lit, utils::expr_to_columns};
use std::collections::HashSet;
+ use std::ops::Add;
#[test]
fn test_collect_expr() -> Result<()> {
@@ -344,4 +395,73 @@ mod tests {
assert!(accum.contains(&Column::from_name("a")));
Ok(())
}
+
+ #[test]
+ fn test_rewrite_preserving_name() {
+ test_rewrite(col("a"), col("a"));
+
+ test_rewrite(col("a"), col("b"));
+
+ // cast data types
+ test_rewrite(
+ col("a"),
+ Expr::Cast {
+ expr: Box::new(col("a")),
+ data_type: DataType::Int32,
+ },
+ );
+
+ // change literal type from i32 to i64
+ test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));
+
+ // SortExpr a+1 ==> b + 2
+ test_rewrite(
+ Expr::Sort {
+ expr: Box::new(col("a").add(lit(1i32))),
+ asc: true,
+ nulls_first: false,
+ },
+ Expr::Sort {
+ expr: Box::new(col("b").add(lit(2i64))),
+ asc: true,
+ nulls_first: false,
+ },
+ );
+ }
+
+ /// rewrites `expr_from` to `rewrite_to` using
+ /// `rewrite_preserving_name`, verifying the result keeps the name of `expr_from`
+ fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
+ struct TestRewriter {
+ rewrite_to: Expr,
+ }
+
+ impl ExprRewriter for TestRewriter {
+ fn mutate(&mut self, _: Expr) -> Result<Expr> {
+ Ok(self.rewrite_to.clone())
+ }
+ }
+
+ let mut rewriter = TestRewriter {
+ rewrite_to: rewrite_to.clone(),
+ };
+ let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap();
+
+ let original_name = match &expr_from {
+ Expr::Sort { expr, .. } => expr.name(),
+ expr => expr.name(),
+ }
+ .unwrap();
+
+ let new_name = match &expr {
+ Expr::Sort { expr, .. } => expr.name(),
+ expr => expr.name(),
+ }
+ .unwrap();
+
+ assert_eq!(
+ original_name, new_name,
+ "mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
+ )
+ }
}