Posted to commits@arrow.apache.org by ag...@apache.org on 2022/08/23 13:38:30 UTC

[arrow-datafusion] branch master updated: optimizer: add framework for the rule of pre-add cast to the literal in comparison binary (#3185)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 9ecf277a3 optimizer: add framework for the rule of pre-add cast to the literal in comparison binary (#3185)
9ecf277a3 is described below

commit 9ecf277a396c300a2ddbd5b2b4ab46947a091a43
Author: Kun Liu <li...@apache.org>
AuthorDate: Tue Aug 23 21:38:24 2022 +0800

    optimizer: add framework for the rule of pre-add cast to the literal in comparison binary (#3185)
    
    * add rule pre add cast to literal
    
    * address comments and fix clippy
    
    * change panic to result
---
 datafusion/core/src/execution/context.rs           |   2 +
 datafusion/core/tests/provider_filter_pushdown.rs  |  34 ++-
 datafusion/core/tests/sql/explain_analyze.rs       |  44 +--
 datafusion/core/tests/sql/subqueries.rs            |   4 +-
 datafusion/optimizer/src/lib.rs                    |   1 +
 .../optimizer/src/pre_cast_lit_in_comparison.rs    | 311 +++++++++++++++++++++
 6 files changed, 370 insertions(+), 26 deletions(-)
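
In short: where a comparison against a bare integer literal previously forced a
per-row cast of the column (the SQL literal parses as Int64), the literal is now
pre-cast and folded at plan time. A before/after sketch, taken from the test
expectations below (c2 is Int32):

    -- the bare literal 10 parses as Int64
    SELECT c1 FROM aggregate_test_100 WHERE c2 > 10;

    -- physical plan before: FilterExec: CAST(c2@1 AS Int64) > 10
    -- physical plan after:  FilterExec: c2@1 > 10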

diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 9c9ed9526..7299ca7ac 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -106,6 +106,7 @@ use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery
 use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
 use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
 use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
+use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
 use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
 use datafusion_sql::{
     parser::DFParser,
@@ -1358,6 +1359,7 @@ impl SessionState {
             // Simplify expressions first to maximize the chance
             // of applying other optimizations
             Arc::new(SimplifyExpressions::new()),
+            Arc::new(PreCastLitInComparisonExpressions::new()),
             Arc::new(DecorrelateWhereExists::new()),
             Arc::new(DecorrelateWhereIn::new()),
             Arc::new(DecorrelateScalarSubquery::new()),
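
The rule is registered directly after SimplifyExpressions in the default pass
list. It can also be applied on its own through the OptimizerRule trait; a
minimal sketch, assuming OptimizerConfig::new() constructs a default config and
`plan` is a LogicalPlan built elsewhere:

    use datafusion_expr::LogicalPlan;
    use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
    use datafusion_optimizer::{OptimizerConfig, OptimizerRule};

    fn apply_rule(plan: &LogicalPlan) -> datafusion_common::Result<LogicalPlan> {
        // Run just this rule; the SessionState above wires it into the full pass list.
        let rule = PreCastLitInComparisonExpressions::new();
        let mut config = OptimizerConfig::new();
        rule.optimize(plan, &mut config)
    }
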
diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/provider_filter_pushdown.rs
index 3ebfec996..8e6d695c9 100644
--- a/datafusion/core/tests/provider_filter_pushdown.rs
+++ b/datafusion/core/tests/provider_filter_pushdown.rs
@@ -31,6 +31,8 @@ use datafusion::physical_plan::{
 };
 use datafusion::prelude::*;
 use datafusion::scalar::ScalarValue;
+use datafusion_common::DataFusionError;
+use std::ops::Deref;
 use std::sync::Arc;
 
 fn create_batch(value: i32, num_rows: usize) -> Result<RecordBatch> {
@@ -146,8 +148,36 @@ impl TableProvider for CustomProvider {
         match &filters[0] {
             Expr::BinaryExpr { right, .. } => {
                 let int_value = match &**right {
-                    Expr::Literal(ScalarValue::Int64(i)) => i.unwrap(),
-                    _ => unimplemented!(),
+                    Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64,
+                    Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64,
+                    Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64,
+                    Expr::Literal(ScalarValue::Int64(Some(i))) => *i as i64,
+                    Expr::Cast { expr, data_type: _ } => match expr.deref() {
+                        Expr::Literal(lit_value) => match lit_value {
+                            ScalarValue::Int8(Some(v)) => *v as i64,
+                            ScalarValue::Int16(Some(v)) => *v as i64,
+                            ScalarValue::Int32(Some(v)) => *v as i64,
+                            ScalarValue::Int64(Some(v)) => *v,
+                            other_value => {
+                                return Err(DataFusionError::NotImplemented(format!(
+                                    "Do not support value {:?}",
+                                    other_value
+                                )))
+                            }
+                        },
+                        other_expr => {
+                            return Err(DataFusionError::NotImplemented(format!(
+                                "Do not support expr {:?}",
+                                other_expr
+                            )))
+                        }
+                    },
+                    other_expr => {
+                        return Err(DataFusionError::NotImplemented(format!(
+                            "Do not support expr {:?}",
+                            other_expr
+                        )))
+                    }
                 };
 
                 Ok(Arc::new(CustomPlan {
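
Because the rule may leave an explicit Expr::Cast around a literal in the
filters a TableProvider sees, the unwrap-then-read pattern above could be
factored into a small recursive helper. A hedged sketch (the helper
literal_as_i64 is hypothetical, not part of this commit):

    use datafusion_common::{DataFusionError, Result, ScalarValue};
    use datafusion_expr::Expr;

    // Hypothetical helper: read a signed integer literal, looking through casts.
    fn literal_as_i64(expr: &Expr) -> Result<i64> {
        match expr {
            Expr::Literal(ScalarValue::Int8(Some(v))) => Ok(*v as i64),
            Expr::Literal(ScalarValue::Int16(Some(v))) => Ok(*v as i64),
            Expr::Literal(ScalarValue::Int32(Some(v))) => Ok(*v as i64),
            Expr::Literal(ScalarValue::Int64(Some(v))) => Ok(*v),
            Expr::Cast { expr, .. } => literal_as_i64(expr),
            other => Err(DataFusionError::NotImplemented(format!(
                "Do not support expr {:?}",
                other
            ))),
        }
    }
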
diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs
index 02db3e873..2b801ed01 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -271,8 +271,8 @@ async fn csv_explain_plans() {
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #aggregate_test_100.c1 [c1:Utf8]",
-        "    Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]",
-        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]",
+        "    Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
+        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
     ];
     let formatted = plan.display_indent_schema().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -286,8 +286,8 @@ async fn csv_explain_plans() {
     let expected = vec![
         "Explain",
         "  Projection: #aggregate_test_100.c1",
-        "    Filter: #aggregate_test_100.c2 > Int64(10)",
-        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]",
+        "    Filter: #aggregate_test_100.c2 > Int32(10)",
+        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
     ];
     let formatted = plan.display_indent().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -307,9 +307,9 @@ async fn csv_explain_plans() {
         "    2[shape=box label=\"Explain\"]",
         "    3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
         "    2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]",
+        "    4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
         "    3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\"]",
+        "    5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
         "    4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "  subgraph cluster_6",
@@ -318,9 +318,9 @@ async fn csv_explain_plans() {
         "    7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
         "    8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
         "    7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
         "    8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        "    10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
         "    9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "}",
@@ -349,7 +349,7 @@ async fn csv_explain_plans() {
     // Since the plan contains paths that are environmentally dependent (e.g. full path of the test file), only verify important content
     assert_contains!(&actual, "logical_plan");
     assert_contains!(&actual, "Projection: #aggregate_test_100.c1");
-    assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int64(10)");
+    assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int32(10)");
 }
 
 #[tokio::test]
@@ -469,8 +469,8 @@ async fn csv_explain_verbose_plans() {
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #aggregate_test_100.c1 [c1:Utf8]",
-        "    Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]",
-        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]",
+        "    Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
+        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
     ];
     let formatted = plan.display_indent_schema().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -484,8 +484,8 @@ async fn csv_explain_verbose_plans() {
     let expected = vec![
         "Explain",
         "  Projection: #aggregate_test_100.c1",
-        "    Filter: #aggregate_test_100.c2 > Int64(10)",
-        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]",
+        "    Filter: #aggregate_test_100.c2 > Int32(10)",
+        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
     ];
     let formatted = plan.display_indent().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -505,9 +505,9 @@ async fn csv_explain_verbose_plans() {
         "    2[shape=box label=\"Explain\"]",
         "    3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
         "    2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]",
+        "    4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
         "    3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\"]",
+        "    5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
         "    4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "  subgraph cluster_6",
@@ -516,9 +516,9 @@ async fn csv_explain_verbose_plans() {
         "    7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
         "    8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
         "    7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
         "    8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        "    10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
         "    9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "}",
@@ -549,7 +549,7 @@ async fn csv_explain_verbose_plans() {
     // important content
     assert_contains!(&actual, "logical_plan after projection_push_down");
     assert_contains!(&actual, "physical_plan");
-    assert_contains!(&actual, "FilterExec: CAST(c2@1 AS Int64) > 10");
+    assert_contains!(&actual, "FilterExec: c2@1 > 10");
     assert_contains!(actual, "ProjectionExec: expr=[c1@0 as c1]");
 }
 
@@ -745,7 +745,7 @@ async fn csv_explain() {
     // then execute the physical plan and return the final explain results
     let ctx = SessionContext::new();
     register_aggregate_csv_by_sql(&ctx).await;
-    let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10";
+    let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > cast(10 as int)";
     let actual = execute(&ctx, sql).await;
     let actual = normalize_vec_for_explain(actual);
 
@@ -755,13 +755,13 @@ async fn csv_explain() {
         vec![
             "logical_plan",
             "Projection: #aggregate_test_100.c1\
-             \n  Filter: #aggregate_test_100.c2 > Int64(10)\
-             \n    TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]"
+             \n  Filter: #aggregate_test_100.c2 > Int32(10)\
+             \n    TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]"
         ],
         vec!["physical_plan",
              "ProjectionExec: expr=[c1@0 as c1]\
               \n  CoalesceBatchesExec: target_batch_size=4096\
-              \n    FilterExec: CAST(c2@1 AS Int64) > 10\
+              \n    FilterExec: c2@1 > 10\
               \n      RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\
               \n        CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\
               \n"
diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
index 4eaf921f6..d85a26932 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -147,8 +147,8 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#;
           Inner Join: #supplier.s_nationkey = #nation.n_nationkey
             Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
               Inner Join: #part.p_partkey = #partsupp.ps_partkey
-                Filter: #part.p_size = Int64(15) AND #part.p_type LIKE Utf8("%BRASS")
-                  TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int64(15), #part.p_type LIKE Utf8("%BRASS")]
+                Filter: #part.p_size = Int32(15) AND #part.p_type LIKE Utf8("%BRASS")
+                  TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int32(15), #part.p_type LIKE Utf8("%BRASS")]
                 TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
               TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
             TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 6da67b6fc..60c450992 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -33,6 +33,7 @@ pub mod single_distinct_to_groupby;
 pub mod subquery_filter_to_join;
 pub mod utils;
 
+pub mod pre_cast_lit_in_comparison;
 pub mod rewrite_disjunctive_predicate;
 #[cfg(test)]
 pub mod test;
diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
new file mode 100644
index 000000000..0c16f7921
--- /dev/null
+++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
@@ -0,0 +1,311 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Pre-cast literal binary comparison rule. It applies only to binary comparison exprs,
+//! and avoids wrapping the non-literal side in an `Expr::Cast` by adding the `Expr::Cast`
+//! to the literal expr instead (which can then be folded to a plain literal).
+use crate::{OptimizerConfig, OptimizerRule};
+use arrow::datatypes::DataType;
+use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
+use datafusion_expr::utils::from_plan;
+use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator};
+
+/// The rule applies only to numeric binary comparisons with a literal expr on one side,
+/// i.e. the patterns `left_expr comparison_op literal_expr` or
+/// `literal_expr comparison_op right_expr`.
+/// For now the data types of both sides must be signed integer types; more data types
+/// will be supported later.
+///
+/// If a binary comparison expr matches the pattern above, the optimizer checks whether the
+/// value of the literal lies within the (min, max) range of the data type of `left_expr`
+/// or `right_expr` on the other side.
+///
+/// If it does, the literal expr is cast to the data type of the expr on the other side,
+/// and the comparison becomes `left_expr comparison_op cast(literal_expr, left_data_type)`
+/// or `cast(literal_expr, right_data_type) comparison_op right_expr`. For better
+/// optimization, `cast(literal_expr, target_type)` is precomputed and folded into a new
+/// literal expr whose data type is `target_type`.
+/// If it does not, the expr is left unchanged.
+///
+/// This is inspired by Spark's optimizer rule `UnwrapCastInBinaryComparison`.
+/// # Example
+///
+/// If the DataType of c1 is INT32, `Filter: c1 > INT64(10)` is optimized to
+/// `Filter: c1 > CAST(INT64(10) AS INT32)`, which is then folded to `Filter: c1 > INT32(10)`.
+///
+#[derive(Default)]
+pub struct PreCastLitInComparisonExpressions {}
+
+impl PreCastLitInComparisonExpressions {
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+impl OptimizerRule for PreCastLitInComparisonExpressions {
+    fn optimize(
+        &self,
+        plan: &LogicalPlan,
+        _optimizer_config: &mut OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        optimize(plan)
+    }
+
+    fn name(&self) -> &str {
+        "pre_cast_lit_in_comparison"
+    }
+}
+
+fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
+    let new_inputs = plan
+        .inputs()
+        .iter()
+        .map(|input| optimize(input))
+        .collect::<Result<Vec<_>>>()?;
+
+    let schema = plan.schema();
+    let new_exprs = plan
+        .expressions()
+        .into_iter()
+        .map(|expr| visit_expr(expr, schema))
+        .collect::<Result<Vec<_>>>()?;
+
+    from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
+}
+
+// Visit every kind of expr; if the current expr has child exprs, the children need to be visited first.
+fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
+    // traverse the expr by dfs
+    match &expr {
+        Expr::BinaryExpr { left, op, right } => {
+            // dfs visit the left and right expr
+            let left = visit_expr(*left.clone(), schema)?;
+            let right = visit_expr(*right.clone(), schema)?;
+            let left_type = left.get_type(schema);
+            let right_type = right.get_type(schema);
+            // can't get the data type, just return the expr
+            if left_type.is_err() || right_type.is_err() {
+                return Ok(expr.clone());
+            }
+            let left_type = left_type.unwrap();
+            let right_type = right_type.unwrap();
+            if !left_type.eq(&right_type)
+                && is_support_data_type(&left_type)
+                && is_support_data_type(&right_type)
+                && is_comparison_op(op)
+            {
+                match (&left, &right) {
+                    (Expr::Literal(_), Expr::Literal(_)) => {
+                        // do nothing
+                    }
+                    (Expr::Literal(left_lit_value), _)
+                        if can_integer_literal_cast_to_type(
+                            left_lit_value,
+                            &right_type,
+                        )? =>
+                    {
+                        // cast the left literal to the right type
+                        return Ok(binary_expr(
+                            cast_to_other_scalar_expr(left_lit_value, &right_type)?,
+                            *op,
+                            right,
+                        ));
+                    }
+                    (_, Expr::Literal(right_lit_value))
+                        if can_integer_literal_cast_to_type(
+                            right_lit_value,
+                            &left_type,
+                        )? =>
+                    {
+                        // cast the right literal to the left type
+                        return Ok(binary_expr(
+                            left,
+                            *op,
+                            cast_to_other_scalar_expr(right_lit_value, &left_type)?,
+                        ));
+                    }
+                    (_, _) => {
+                        // do nothing
+                    }
+                };
+            }
+            // return the new binary op
+            Ok(binary_expr(left, *op, right))
+        }
+        // TODO: optimize in list
+        // Expr::InList { .. } => {}
+        // TODO: handle other expr type and dfs visit them
+        _ => Ok(expr),
+    }
+}
+
+fn cast_to_other_scalar_expr(
+    origin_value: &ScalarValue,
+    target_type: &DataType,
+) -> Result<Expr> {
+    // null case
+    if origin_value.is_null() {
+        // if the origin value is null, just convert it to a null value of the target type
+        // The target type has already passed the `is_support_data_type` check, so we can unwrap safely
+        return Ok(lit(ScalarValue::try_from(target_type).unwrap()));
+    }
+    // no null case
+    let value: i64 = match origin_value {
+        ScalarValue::Int8(Some(v)) => *v as i64,
+        ScalarValue::Int16(Some(v)) => *v as i64,
+        ScalarValue::Int32(Some(v)) => *v as i64,
+        ScalarValue::Int64(Some(v)) => *v,
+        other_value => {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid type and value {}",
+                other_value
+            )))
+        }
+    };
+    Ok(lit(match target_type {
+        DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
+        DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
+        DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
+        DataType::Int64 => ScalarValue::Int64(Some(value)),
+        other_type => {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid target data type {:?}",
+                other_type
+            )))
+        }
+    }))
+}
+
+fn is_comparison_op(op: &Operator) -> bool {
+    matches!(
+        op,
+        Operator::Eq
+            | Operator::NotEq
+            | Operator::Gt
+            | Operator::GtEq
+            | Operator::Lt
+            | Operator::LtEq
+    )
+}
+
+fn is_support_data_type(data_type: &DataType) -> bool {
+    // TODO support decimal with other data type
+    matches!(
+        data_type,
+        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
+    )
+}
+
+fn can_integer_literal_cast_to_type(
+    integer_lit_value: &ScalarValue,
+    target_type: &DataType,
+) -> Result<bool> {
+    if integer_lit_value.is_null() {
+        // a null value can be cast to a null value of any type
+        return Ok(true);
+    }
+    let (target_min, target_max) = match target_type {
+        DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
+        DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
+        DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
+        DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
+        other_type => {
+            return Err(DataFusionError::Internal(format!(
+                "Error target data type {:?}",
+                other_type
+            )))
+        }
+    };
+    let lit_value = match integer_lit_value {
+        ScalarValue::Int8(Some(v)) => *v as i128,
+        ScalarValue::Int16(Some(v)) => *v as i128,
+        ScalarValue::Int32(Some(v)) => *v as i128,
+        ScalarValue::Int64(Some(v)) => *v as i128,
+        other_value => {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid literal value {:?}",
+                other_value
+            )))
+        }
+    };
+
+    Ok(lit_value >= target_min && lit_value <= target_max)
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::pre_cast_lit_in_comparison::visit_expr;
+    use arrow::datatypes::DataType;
+    use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue};
+    use datafusion_expr::{col, lit, Expr};
+    use std::collections::HashMap;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_not_cast_lit_comparison() {
+        let schema = expr_test_schema();
+        // INT8(NULL) < INT32(12)
+        let lit_lt_lit =
+            lit(ScalarValue::Int8(None)).lt(lit(ScalarValue::Int32(Some(12))));
+        assert_eq!(optimize_test(lit_lt_lit.clone(), &schema), lit_lt_lit);
+        // INT32(c1) > INT64(c2)
+        let c1_gt_c2 = col("c1").gt(col("c2"));
+        assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2);
+
+        // INT32(c1) < INT32(16): the types are the same, no cast is needed
+        let expr_lt = col("c1").lt(lit(ScalarValue::Int32(Some(16))));
+        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+        // 99999999999 is outside the [MIN(int32), MAX(int32)] range, so we don't cast lit(99999999999) to int32
+        let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(99999999999))));
+        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+    }
+
+    #[test]
+    fn test_pre_cast_lit_comparison() {
+        let schema = expr_test_schema();
+        // c1 < INT64(16) -> c1 < INT32(16)
+        // 16 is within the [MIN(int32), MAX(int32)] range, so we can cast it to int32(16)
+        let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(16))));
+        let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16))));
+        assert_eq!(optimize_test(expr_lt, &schema), expected);
+
+        // INT64(c2) = INT32(16) => INT64(c2) = INT64(16)
+        let c2_eq_lit = col("c2").eq(lit(ScalarValue::Int32(Some(16))));
+        let expected = col("c2").eq(lit(ScalarValue::Int64(Some(16))));
+        assert_eq!(optimize_test(c2_eq_lit, &schema), expected);
+
+        // INT32(c1) < INT64(NULL) => INT32(c1) < INT32(NULL)
+        let c1_lt_lit_null = col("c1").lt(lit(ScalarValue::Int64(None)));
+        let expected = col("c1").lt(lit(ScalarValue::Int32(None)));
+        assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
+    }
+
+    fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
+        visit_expr(expr, schema).unwrap()
+    }
+
+    fn expr_test_schema() -> DFSchemaRef {
+        Arc::new(
+            DFSchema::new_with_metadata(
+                vec![
+                    DFField::new(None, "c1", DataType::Int32, false),
+                    DFField::new(None, "c2", DataType::Int64, false),
+                ],
+                HashMap::new(),
+            )
+            .unwrap(),
+        )
+    }
+}
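
For reference, the bounds check these tests exercise can be reproduced without
any DataFusion types; a standalone sketch of the same i128-widening comparison
that can_integer_literal_cast_to_type performs:

    // Widen both the literal and the i32 bounds to i128, then compare,
    // as can_integer_literal_cast_to_type does for each supported target type.
    fn fits_in_i32(lit_value: i64) -> bool {
        let v = lit_value as i128;
        v >= i32::MIN as i128 && v <= i32::MAX as i128
    }

    fn main() {
        assert!(fits_in_i32(16)); // Int64(16) can be folded to Int32(16)
        assert!(!fits_in_i32(99999999999)); // outside i32 range: literal stays Int64
    }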