You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/08/25 17:24:06 UTC
[arrow-datafusion] branch master updated: Use `ExprRewriter` in `pre_cast_lit_in_comparison` (#3260)
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 5ee52d08e Use `ExprRewriter` in `pre_cast_lit_in_comparison` (#3260)
5ee52d08e is described below
commit 5ee52d08ec83e39da6f5e7b3567fa112d3f022d4
Author: Andy Grove <an...@gmail.com>
AuthorDate: Thu Aug 25 11:24:00 2022 -0600
Use `ExprRewriter` in `pre_cast_lit_in_comparison` (#3260)
* Use ExprRewriter in pre_cast_lit_in_comparison.rs
* remove manual recursion and add a nested test case
---
.../optimizer/src/pre_cast_lit_in_comparison.rs | 172 +++++++++++++--------
1 file changed, 107 insertions(+), 65 deletions(-)
diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
index 0c16f7921..68c738ca8 100644
--- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
+++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
@@ -20,6 +20,7 @@
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
+use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator};
@@ -74,79 +75,92 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
.collect::<Result<Vec<_>>>()?;
let schema = plan.schema();
+
+ let mut expr_rewriter = PreCastLitExprRewriter {
+ schema: schema.clone(),
+ };
+
let new_exprs = plan
.expressions()
.into_iter()
- .map(|expr| visit_expr(expr, schema))
+ .map(|expr| expr.rewrite(&mut expr_rewriter))
.collect::<Result<Vec<_>>>()?;
from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
}
-// Visit all type of expr, if the current has child expr, the child expr needed to visit first.
-fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
-    // traverse the expr by dfs
-    match &expr {
-        Expr::BinaryExpr { left, op, right } => {
-            // dfs visit the left and right expr
-            let left = visit_expr(*left.clone(), schema)?;
-            let right = visit_expr(*right.clone(), schema)?;
-            let left_type = left.get_type(schema);
-            let right_type = right.get_type(schema);
-            // can't get the data type, just return the expr
-            if left_type.is_err() || right_type.is_err() {
-                return Ok(expr.clone());
-            }
-            let left_type = left_type.unwrap();
-            let right_type = right_type.unwrap();
-            if !left_type.eq(&right_type)
-                && is_support_data_type(&left_type)
-                && is_support_data_type(&right_type)
-                && is_comparison_op(op)
-            {
-                match (&left, &right) {
-                    (Expr::Literal(_), Expr::Literal(_)) => {
-                        // do nothing
-                    }
-                    (Expr::Literal(left_lit_value), _)
-                        if can_integer_literal_cast_to_type(
-                            left_lit_value,
-                            &right_type,
-                        )? =>
-                    {
-                        // cast the left literal to the right type
-                        return Ok(binary_expr(
-                            cast_to_other_scalar_expr(left_lit_value, &right_type)?,
-                            *op,
-                            right,
-                        ));
-                    }
-                    (_, Expr::Literal(right_lit_value))
-                        if can_integer_literal_cast_to_type(
-                            right_lit_value,
-                            &left_type,
-                        )
-                        .unwrap() =>
-                    {
-                        // cast the right literal to the left type
-                        return Ok(binary_expr(
-                            left,
-                            *op,
-                            cast_to_other_scalar_expr(right_lit_value, &left_type)?,
-                        ));
-                    }
-                    (_, _) => {
-                        // do nothing
-                    }
-                };
+/// Casts an integer literal compared against another operand to that operand's type when allowed.
+struct PreCastLitExprRewriter {
+    schema: DFSchemaRef,
+}
+
+impl ExprRewriter for PreCastLitExprRewriter {
+    fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
+        Ok(RewriteRecursion::Continue)
+    }
+
+    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+        // children were already rewritten by `ExprRewritable::rewrite`; handle this node only
+        match &expr {
+            Expr::BinaryExpr { left, op, right } => {
+                let left = left.as_ref().clone();
+                let right = right.as_ref().clone();
+                let left_type = left.get_type(&self.schema);
+                let right_type = right.get_type(&self.schema);
+                // can't get the data type, just return the expr
+                if left_type.is_err() || right_type.is_err() {
+                    return Ok(expr);
+                }
+                let left_type = left_type?;
+                let right_type = right_type?;
+                if !left_type.eq(&right_type)
+                    && is_support_data_type(&left_type)
+                    && is_support_data_type(&right_type)
+                    && is_comparison_op(op)
+                {
+                    match (&left, &right) {
+                        (Expr::Literal(_), Expr::Literal(_)) => {
+                            // do nothing
+                        }
+                        (Expr::Literal(left_lit_value), _)
+                            if can_integer_literal_cast_to_type(
+                                left_lit_value,
+                                &right_type,
+                            )? =>
+                        {
+                            // cast the left literal to the right type
+                            return Ok(binary_expr(
+                                cast_to_other_scalar_expr(left_lit_value, &right_type)?,
+                                *op,
+                                right,
+                            ));
+                        }
+                        (_, Expr::Literal(right_lit_value))
+                            if can_integer_literal_cast_to_type(
+                                right_lit_value,
+                                &left_type,
+                            )? =>
+                        {
+                            // cast the right literal to the left type
+                            return Ok(binary_expr(
+                                left,
+                                *op,
+                                cast_to_other_scalar_expr(right_lit_value, &left_type)?,
+                            ));
+                        }
+                        (_, _) => {
+                            // do nothing
+                        }
+                    };
+                }
+                // nothing to cast: rebuild the binary expr from its operands
+                Ok(binary_expr(left, *op, right))
             }
-            // return the new binary op
-            Ok(binary_expr(left, *op, right))
+            // TODO: optimize in list
+            // Expr::InList { .. } => {}
+            // TODO: rewrite other expr types (their children are already visited)
+            _ => Ok(expr),
         }
-        // TODO: optimize in list
-        // Expr::InList { .. } => {}
-        // TODO: handle other expr type and dfs visit them
-        _ => Ok(expr),
     }
 }
@@ -245,9 +259,10 @@ fn can_integer_literal_cast_to_type(
#[cfg(test)]
mod tests {
- use crate::pre_cast_lit_in_comparison::visit_expr;
+ use crate::pre_cast_lit_in_comparison::PreCastLitExprRewriter;
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue};
+ use datafusion_expr::expr_rewriter::ExprRewritable;
use datafusion_expr::{col, lit, Expr};
use std::collections::HashMap;
use std::sync::Arc;
@@ -292,8 +307,35 @@ mod tests {
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
}
+    #[test]
+    fn aliased() {
+        let schema = expr_test_schema();
+        // (c1 < INT64(16)) AS x -> (c1 < INT32(16)) AS x
+        // the comparison under the alias is rewritten (16 fits in int32) and the alias is kept
+        let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(16)))).alias("x");
+        let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16)))).alias("x");
+        assert_eq!(optimize_test(expr_lt, &schema), expected);
+    }
+
+    #[test]
+    fn nested() {
+        let schema = expr_test_schema();
+        // (c1 < INT64(16)) OR (c1 > INT64(32)) -> (c1 < INT32(16)) OR (c1 > INT32(32))
+        // both literals sit under the OR and fit in int32, so each one gets cast
+        let lower = col("c1").lt(lit(ScalarValue::Int64(Some(16))));
+        let upper = col("c1").gt(lit(ScalarValue::Int64(Some(32))));
+        let input = lower.or(upper);
+        let expected_lower = col("c1").lt(lit(ScalarValue::Int32(Some(16))));
+        let expected_upper = col("c1").gt(lit(ScalarValue::Int32(Some(32))));
+        let expected = expected_lower.or(expected_upper);
+        assert_eq!(optimize_test(input, &schema), expected);
+    }
+
fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
- visit_expr(expr, schema).unwrap()
+ let mut expr_rewriter = PreCastLitExprRewriter {
+ schema: schema.clone(),
+ };
+ expr.rewrite(&mut expr_rewriter).unwrap()
}
fn expr_test_schema() -> DFSchemaRef {