You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/08/23 13:38:30 UTC
[arrow-datafusion] branch master updated: optimizer: add framework for the rule of pre-add cast to the literal in comparison binary (#3185)
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 9ecf277a3 optimizer: add framework for the rule of pre-add cast to the literal in comparison binary (#3185)
9ecf277a3 is described below
commit 9ecf277a396c300a2ddbd5b2b4ab46947a091a43
Author: Kun Liu <li...@apache.org>
AuthorDate: Tue Aug 23 21:38:24 2022 +0800
optimizer: add framework for the rule of pre-add cast to the literal in comparison binary (#3185)
* add rule pre add cast to literal
* address comments and fix clippy
* change panic to result
---
datafusion/core/src/execution/context.rs | 2 +
datafusion/core/tests/provider_filter_pushdown.rs | 34 ++-
datafusion/core/tests/sql/explain_analyze.rs | 44 +--
datafusion/core/tests/sql/subqueries.rs | 4 +-
datafusion/optimizer/src/lib.rs | 1 +
.../optimizer/src/pre_cast_lit_in_comparison.rs | 311 +++++++++++++++++++++
6 files changed, 370 insertions(+), 26 deletions(-)
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 9c9ed9526..7299ca7ac 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -106,6 +106,7 @@ use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery
use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
+use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use datafusion_sql::{
parser::DFParser,
@@ -1358,6 +1359,7 @@ impl SessionState {
// Simplify expressions first to maximize the chance
// of applying other optimizations
Arc::new(SimplifyExpressions::new()),
+ Arc::new(PreCastLitInComparisonExpressions::new()),
Arc::new(DecorrelateWhereExists::new()),
Arc::new(DecorrelateWhereIn::new()),
Arc::new(DecorrelateScalarSubquery::new()),
diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/provider_filter_pushdown.rs
index 3ebfec996..8e6d695c9 100644
--- a/datafusion/core/tests/provider_filter_pushdown.rs
+++ b/datafusion/core/tests/provider_filter_pushdown.rs
@@ -31,6 +31,8 @@ use datafusion::physical_plan::{
};
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
+use datafusion_common::DataFusionError;
+use std::ops::Deref;
use std::sync::Arc;
fn create_batch(value: i32, num_rows: usize) -> Result<RecordBatch> {
@@ -146,8 +148,36 @@ impl TableProvider for CustomProvider {
match &filters[0] {
Expr::BinaryExpr { right, .. } => {
let int_value = match &**right {
- Expr::Literal(ScalarValue::Int64(i)) => i.unwrap(),
- _ => unimplemented!(),
+ Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64,
+ Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64,
+ Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64,
+ Expr::Literal(ScalarValue::Int64(Some(i))) => *i as i64,
+ Expr::Cast { expr, data_type: _ } => match expr.deref() {
+ Expr::Literal(lit_value) => match lit_value {
+ ScalarValue::Int8(Some(v)) => *v as i64,
+ ScalarValue::Int16(Some(v)) => *v as i64,
+ ScalarValue::Int32(Some(v)) => *v as i64,
+ ScalarValue::Int64(Some(v)) => *v,
+ other_value => {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Do not support value {:?}",
+ other_value
+ )))
+ }
+ },
+ other_expr => {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Do not support expr {:?}",
+ other_expr
+ )))
+ }
+ },
+ other_expr => {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Do not support expr {:?}",
+ other_expr
+ )))
+ }
};
Ok(Arc::new(CustomPlan {
diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs
index 02db3e873..2b801ed01 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -271,8 +271,8 @@ async fn csv_explain_plans() {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #aggregate_test_100.c1 [c1:Utf8]",
- " Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]",
- " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]",
+ " Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
+ " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -286,8 +286,8 @@ async fn csv_explain_plans() {
let expected = vec![
"Explain",
" Projection: #aggregate_test_100.c1",
- " Filter: #aggregate_test_100.c2 > Int64(10)",
- " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]",
+ " Filter: #aggregate_test_100.c2 > Int32(10)",
+ " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
];
let formatted = plan.display_indent().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -307,9 +307,9 @@ async fn csv_explain_plans() {
" 2[shape=box label=\"Explain\"]",
" 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
" 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
- " 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]",
+ " 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
" 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
- " 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\"]",
+ " 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
" 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
" subgraph cluster_6",
@@ -318,9 +318,9 @@ async fn csv_explain_plans() {
" 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
" 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
" 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
- " 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
+ " 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
- " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
+ " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
"}",
@@ -349,7 +349,7 @@ async fn csv_explain_plans() {
// Since the plan contains path that are environmentally dependant (e.g. full path of the test file), only verify important content
assert_contains!(&actual, "logical_plan");
assert_contains!(&actual, "Projection: #aggregate_test_100.c1");
- assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int64(10)");
+ assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int32(10)");
}
#[tokio::test]
@@ -469,8 +469,8 @@ async fn csv_explain_verbose_plans() {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #aggregate_test_100.c1 [c1:Utf8]",
- " Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]",
- " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]",
+ " Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
+ " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -484,8 +484,8 @@ async fn csv_explain_verbose_plans() {
let expected = vec![
"Explain",
" Projection: #aggregate_test_100.c1",
- " Filter: #aggregate_test_100.c2 > Int64(10)",
- " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]",
+ " Filter: #aggregate_test_100.c2 > Int32(10)",
+ " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
];
let formatted = plan.display_indent().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -505,9 +505,9 @@ async fn csv_explain_verbose_plans() {
" 2[shape=box label=\"Explain\"]",
" 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
" 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
- " 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]",
+ " 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
" 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
- " 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\"]",
+ " 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
" 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
" subgraph cluster_6",
@@ -516,9 +516,9 @@ async fn csv_explain_verbose_plans() {
" 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
" 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
" 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
- " 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
+ " 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
- " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
+ " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
"}",
@@ -549,7 +549,7 @@ async fn csv_explain_verbose_plans() {
// important content
assert_contains!(&actual, "logical_plan after projection_push_down");
assert_contains!(&actual, "physical_plan");
- assert_contains!(&actual, "FilterExec: CAST(c2@1 AS Int64) > 10");
+ assert_contains!(&actual, "FilterExec: c2@1 > 10");
assert_contains!(actual, "ProjectionExec: expr=[c1@0 as c1]");
}
@@ -745,7 +745,7 @@ async fn csv_explain() {
// then execute the physical plan and return the final explain results
let ctx = SessionContext::new();
register_aggregate_csv_by_sql(&ctx).await;
- let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10";
+ let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > cast(10 as int)";
let actual = execute(&ctx, sql).await;
let actual = normalize_vec_for_explain(actual);
@@ -755,13 +755,13 @@ async fn csv_explain() {
vec![
"logical_plan",
"Projection: #aggregate_test_100.c1\
- \n Filter: #aggregate_test_100.c2 > Int64(10)\
- \n TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]"
+ \n Filter: #aggregate_test_100.c2 > Int32(10)\
+ \n TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]"
],
vec!["physical_plan",
"ProjectionExec: expr=[c1@0 as c1]\
\n CoalesceBatchesExec: target_batch_size=4096\
- \n FilterExec: CAST(c2@1 AS Int64) > 10\
+ \n FilterExec: c2@1 > 10\
\n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\
\n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\
\n"
diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
index 4eaf921f6..d85a26932 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -147,8 +147,8 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#;
Inner Join: #supplier.s_nationkey = #nation.n_nationkey
Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
Inner Join: #part.p_partkey = #partsupp.ps_partkey
- Filter: #part.p_size = Int64(15) AND #part.p_type LIKE Utf8("%BRASS")
- TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int64(15), #part.p_type LIKE Utf8("%BRASS")]
+ Filter: #part.p_size = Int32(15) AND #part.p_type LIKE Utf8("%BRASS")
+ TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int32(15), #part.p_type LIKE Utf8("%BRASS")]
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 6da67b6fc..60c450992 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -33,6 +33,7 @@ pub mod single_distinct_to_groupby;
pub mod subquery_filter_to_join;
pub mod utils;
+pub mod pre_cast_lit_in_comparison;
pub mod rewrite_disjunctive_predicate;
#[cfg(test)]
pub mod test;
diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
new file mode 100644
index 000000000..0c16f7921
--- /dev/null
+++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
@@ -0,0 +1,311 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
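+//! Optimizer rule that pre-casts literals in binary comparison expressions.
+//! When a comparison mixes an expression with a literal of a different integer type, the
+//! literal is cast to the expression's type at planning time, so no `Expr::Cast` has to be
+//! wrapped around the (non-literal) expression.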
+use crate::{OptimizerConfig, OptimizerRule};
+use arrow::datatypes::DataType;
+use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
+use datafusion_expr::utils::from_plan;
+use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator};
+
+/// Optimizer rule that pre-casts integer literals in binary comparisons.
+///
+/// It applies only to numeric comparisons that involve a literal, i.e. the patterns
+/// `left_expr comparison_op literal_expr` and `literal_expr comparison_op right_expr`.
+/// Currently both sides must be signed integer types (`Int8`, `Int16`, `Int32`, `Int64`);
+/// more data types may be supported later.
+///
+/// For a matching expression, the rule checks whether the literal's value lies within the
+/// inclusive `[min, max]` range of the data type of the expression on the other side.
+///
+/// If it does, the literal is cast to that data type, giving
+/// `left_expr comparison_op cast(literal_expr, left_data_type)` or
+/// `cast(literal_expr, right_data_type) comparison_op right_expr`. The cast is evaluated
+/// while planning, so the result is a new literal of the target type rather than an
+/// `Expr::Cast` node. A null literal becomes a null of the target type.
+/// If it does not, the expression is left unchanged.
+///
+/// This is inspired by Spark's `UnwrapCastInBinaryComparison` optimizer rule.
+///
+/// # Example
+///
+/// If the data type of `c1` is `Int32`, `Filter: c1 > Int64(10)` is conceptually rewritten
+/// to `Filter: c1 > CAST(Int64(10) AS Int32)`, which is then folded to `Filter: c1 > Int32(10)`.
+#[derive(Default)]
+pub struct PreCastLitInComparisonExpressions {}
+
+impl PreCastLitInComparisonExpressions {
+    /// Creates a new instance of the rule (the rule carries no state).
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl OptimizerRule for PreCastLitInComparisonExpressions {
+    /// Rewrites the comparison expressions of `plan` and all of its inputs.
+    /// The optimizer configuration is not consulted by this rule.
+    fn optimize(
+        &self,
+        plan: &LogicalPlan,
+        _optimizer_config: &mut OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        // Delegate to the module-level function that does the recursive rewrite.
+        self::optimize(plan)
+    }
+
+    /// Identifier under which this rule is reported.
+    fn name(&self) -> &str {
+        let rule_name = "pre_cast_lit_in_comparison";
+        rule_name
+    }
+}
+
+/// Applies the rule to `plan` and, recursively, to all of its inputs.
+///
+/// Inputs are rewritten first (bottom-up). Then every expression owned by `plan` is
+/// rewritten with `visit_expr`, and the node is rebuilt with `from_plan`. The first error
+/// encountered is returned.
+///
+/// NOTE(review): expressions are typed against `plan.schema()`, i.e. this node's own
+/// output schema. For nodes whose expressions refer to their input schema (e.g. a
+/// projection), resolving a column's type may fail, and then that comparison is not
+/// cast. Confirm whether the input schema should be used instead.
+fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
+    let mut rewritten_inputs = Vec::new();
+    for child in plan.inputs() {
+        rewritten_inputs.push(optimize(child)?);
+    }
+
+    let schema = plan.schema();
+    let mut rewritten_exprs = Vec::new();
+    for plan_expr in plan.expressions() {
+        rewritten_exprs.push(visit_expr(plan_expr, schema)?);
+    }
+
+    from_plan(plan, &rewritten_exprs, &rewritten_inputs)
+}
+
+// Rewrites `expr` depth-first: the children of a binary expression are visited before
+// the expression itself.
+//
+// For a comparison between a signed-integer literal and a signed-integer expression of a
+// different type, the literal is replaced by an equivalent literal of the other side's
+// type when its value fits that type's range. Only `Expr::BinaryExpr` is traversed for
+// now; any other expression is returned unchanged.
+//
+// Errors from the range check or the cast are propagated rather than panicking.
+fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
+    match expr {
+        Expr::BinaryExpr { left, op, right } => {
+            // DFS: optimize both children first.
+            let left = visit_expr(*left, schema)?;
+            let right = visit_expr(*right, schema)?;
+
+            // If either side's type cannot be resolved, skip the cast at this level but
+            // keep the (already optimized) children instead of discarding their rewrites.
+            let (left_type, right_type) =
+                match (left.get_type(schema), right.get_type(schema)) {
+                    (Ok(left_type), Ok(right_type)) => (left_type, right_type),
+                    _ => return Ok(binary_expr(left, op, right)),
+                };
+
+            if left_type != right_type
+                && is_support_data_type(&left_type)
+                && is_support_data_type(&right_type)
+                && is_comparison_op(&op)
+            {
+                match (&left, &right) {
+                    // Two literals: nothing to gain by casting either one.
+                    (Expr::Literal(_), Expr::Literal(_)) => {}
+                    (Expr::Literal(left_lit_value), _)
+                        if can_integer_literal_cast_to_type(left_lit_value, &right_type)? =>
+                    {
+                        // `literal op expr` => `cast(literal, right_type) op expr`
+                        let new_left = cast_to_other_scalar_expr(left_lit_value, &right_type)?;
+                        return Ok(binary_expr(new_left, op, right));
+                    }
+                    (_, Expr::Literal(right_lit_value))
+                        if can_integer_literal_cast_to_type(right_lit_value, &left_type)? =>
+                    {
+                        // `expr op literal` => `expr op cast(literal, left_type)`
+                        let new_right = cast_to_other_scalar_expr(right_lit_value, &left_type)?;
+                        return Ok(binary_expr(left, op, new_right));
+                    }
+                    _ => {}
+                }
+            }
+            // No cast applies: rebuild the binary expression from the optimized children.
+            Ok(binary_expr(left, op, right))
+        }
+        // TODO: optimize `Expr::InList` and traverse the remaining expression types.
+        other => Ok(other),
+    }
+}
+
+/// Converts the integer literal `origin_value` into a literal of `target_type`.
+///
+/// A null literal becomes a null of `target_type`. Callers are expected to have checked
+/// with `can_integer_literal_cast_to_type` that the value fits. The narrowing here is
+/// nevertheless checked, so an out-of-range value produces an error instead of silently
+/// wrapping around.
+///
+/// # Errors
+/// Returns `DataFusionError::Internal` if any of the following holds:
+/// - `origin_value` is not a signed integer value;
+/// - `target_type` is not a signed integer type;
+/// - the value does not fit in `target_type`.
+///
+/// Errors from building a typed null are propagated as well.
+fn cast_to_other_scalar_expr(
+    origin_value: &ScalarValue,
+    target_type: &DataType,
+) -> Result<Expr> {
+    // null case: convert to a null value of the target type
+    if origin_value.is_null() {
+        return Ok(lit(ScalarValue::try_from(target_type)?));
+    }
+    // non-null case: widen the literal to i64 losslessly
+    let value: i64 = match origin_value {
+        ScalarValue::Int8(Some(v)) => i64::from(*v),
+        ScalarValue::Int16(Some(v)) => i64::from(*v),
+        ScalarValue::Int32(Some(v)) => i64::from(*v),
+        ScalarValue::Int64(Some(v)) => *v,
+        other_value => {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid type and value {}",
+                other_value
+            )))
+        }
+    };
+    let out_of_range = || {
+        DataFusionError::Internal(format!(
+            "Value {} is out of range for target data type {:?}",
+            value, target_type
+        ))
+    };
+    // checked narrowing: never truncate silently
+    let casted = match target_type {
+        DataType::Int8 => {
+            ScalarValue::Int8(Some(i8::try_from(value).map_err(|_| out_of_range())?))
+        }
+        DataType::Int16 => {
+            ScalarValue::Int16(Some(i16::try_from(value).map_err(|_| out_of_range())?))
+        }
+        DataType::Int32 => {
+            ScalarValue::Int32(Some(i32::try_from(value).map_err(|_| out_of_range())?))
+        }
+        DataType::Int64 => ScalarValue::Int64(Some(value)),
+        other_type => {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid target data type {:?}",
+                other_type
+            )))
+        }
+    };
+    Ok(lit(casted))
+}
+
+/// Returns `true` if `op` is one of the six equality/ordering comparison operators
+/// (`=`, `!=`, `>`, `>=`, `<`, `<=`).
+fn is_comparison_op(op: &Operator) -> bool {
+    let comparison_ops = [
+        Operator::Eq,
+        Operator::NotEq,
+        Operator::Gt,
+        Operator::GtEq,
+        Operator::Lt,
+        Operator::LtEq,
+    ];
+    comparison_ops.contains(op)
+}
+
+/// Returns `true` for the data types this rule currently handles: the signed integer
+/// types `Int8`, `Int16`, `Int32` and `Int64`.
+fn is_support_data_type(data_type: &DataType) -> bool {
+    // TODO: support decimal and other data types
+    let supported = [
+        DataType::Int8,
+        DataType::Int16,
+        DataType::Int32,
+        DataType::Int64,
+    ];
+    supported.contains(data_type)
+}
+
+/// Checks whether an integer literal can be represented in `target_type` without overflow.
+///
+/// Null literals always qualify, since they become a null of the target type.
+///
+/// # Errors
+/// Returns `DataFusionError::Internal` if `target_type` is not `Int8`/`Int16`/`Int32`/`Int64`,
+/// or if a non-null `integer_lit_value` is not one of those integer variants.
+/// The target type is validated before the literal.
+fn can_integer_literal_cast_to_type(
+    integer_lit_value: &ScalarValue,
+    target_type: &DataType,
+) -> Result<bool> {
+    if integer_lit_value.is_null() {
+        return Ok(true);
+    }
+
+    // Inclusive value range of the target type, widened to i128 so that every
+    // supported literal can be compared without overflow.
+    let target_range = match target_type {
+        DataType::Int8 => i128::from(i8::MIN)..=i128::from(i8::MAX),
+        DataType::Int16 => i128::from(i16::MIN)..=i128::from(i16::MAX),
+        DataType::Int32 => i128::from(i32::MIN)..=i128::from(i32::MAX),
+        DataType::Int64 => i128::from(i64::MIN)..=i128::from(i64::MAX),
+        unsupported => {
+            return Err(DataFusionError::Internal(format!(
+                "Error target data type {:?}",
+                unsupported
+            )))
+        }
+    };
+
+    let literal = match integer_lit_value {
+        ScalarValue::Int8(Some(v)) => i128::from(*v),
+        ScalarValue::Int16(Some(v)) => i128::from(*v),
+        ScalarValue::Int32(Some(v)) => i128::from(*v),
+        ScalarValue::Int64(Some(v)) => i128::from(*v),
+        unsupported => {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid literal value {:?}",
+                unsupported
+            )))
+        }
+    };
+
+    Ok(target_range.contains(&literal))
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::pre_cast_lit_in_comparison::visit_expr;
+    use arrow::datatypes::DataType;
+    use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue};
+    use datafusion_expr::{col, lit, Expr};
+    use std::collections::HashMap;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_not_cast_lit_comparison() {
+        let schema = expr_test_schema();
+        // INT8(NULL) < INT32(12)
+        let lit_lt_lit =
+            lit(ScalarValue::Int8(None)).lt(lit(ScalarValue::Int32(Some(12))));
+        assert_eq!(optimize_test(lit_lt_lit.clone(), &schema), lit_lt_lit);
+        // INT32(c1) > INT64(c2)
+        let c1_gt_c2 = col("c1").gt(col("c2"));
+        assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2);
+
+        // INT32(c1) < INT32(16), the type is same
+        let expr_lt = col("c1").lt(lit(ScalarValue::Int32(Some(16))));
+        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+        // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type
+        let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(99999999999))));
+        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+    }
+
+    #[test]
+    fn test_pre_cast_lit_comparison() {
+        let schema = expr_test_schema();
+        // c1 < INT64(16) -> c1 < cast(INT32(16))
+        // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16)
+        let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(16))));
+        let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16))));
+        assert_eq!(optimize_test(expr_lt, &schema), expected);
+
+        // INT64(c2) = INT32(16) => INT64(c2) = INT64(16)
+        let c2_eq_lit = col("c2").eq(lit(ScalarValue::Int32(Some(16))));
+        let expected = col("c2").eq(lit(ScalarValue::Int64(Some(16))));
+        assert_eq!(optimize_test(c2_eq_lit, &schema), expected);
+
+        // INT32(c1) < INT64(NULL) => INT32(c1) < INT32(NULL)
+        let c1_lt_lit_null = col("c1").lt(lit(ScalarValue::Int64(None)));
+        let expected = col("c1").lt(lit(ScalarValue::Int32(None)));
+        assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
+    }
+
+    #[test]
+    fn test_pre_cast_lit_on_left_side() {
+        let schema = expr_test_schema();
+        // INT64(16) > INT32(c1) => INT32(16) > INT32(c1)
+        let lit_gt_c1 = lit(ScalarValue::Int64(Some(16))).gt(col("c1"));
+        let expected = lit(ScalarValue::Int32(Some(16))).gt(col("c1"));
+        assert_eq!(optimize_test(lit_gt_c1, &schema), expected);
+
+        // INT64(99999999999) > INT32(c1): outside the int32 range, left unchanged
+        let big_lit_gt_c1 = lit(ScalarValue::Int64(Some(99999999999))).gt(col("c1"));
+        assert_eq!(optimize_test(big_lit_gt_c1.clone(), &schema), big_lit_gt_c1);
+    }
+
+    #[test]
+    fn test_pre_cast_lit_in_nested_expr() {
+        let schema = expr_test_schema();
+        // (c1 < INT64(16)) AND (c2 = INT32(5)) => (c1 < INT32(16)) AND (c2 = INT64(5))
+        let nested = col("c1")
+            .lt(lit(ScalarValue::Int64(Some(16))))
+            .and(col("c2").eq(lit(ScalarValue::Int32(Some(5)))));
+        let expected = col("c1")
+            .lt(lit(ScalarValue::Int32(Some(16))))
+            .and(col("c2").eq(lit(ScalarValue::Int64(Some(5)))));
+        assert_eq!(optimize_test(nested, &schema), expected);
+    }
+
+    fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
+        visit_expr(expr, schema).unwrap()
+    }
+
+    fn expr_test_schema() -> DFSchemaRef {
+        Arc::new(
+            DFSchema::new_with_metadata(
+                vec![
+                    DFField::new(None, "c1", DataType::Int32, false),
+                    DFField::new(None, "c2", DataType::Int64, false),
+                ],
+                HashMap::new(),
+            )
+            .unwrap(),
+        )
+    }
+}