You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ja...@apache.org on 2023/04/08 10:19:11 UTC

[arrow-datafusion] branch main updated: fix: type coercion for expr/subquery in InSubquery (#5883)

This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new bc0af737bc fix: type coercion for expr/subquery in InSubquery (#5883)
bc0af737bc is described below

commit bc0af737bce3c4283aad2b94752dbf5125754bf1
Author: jakevin <ja...@gmail.com>
AuthorDate: Sat Apr 8 18:19:03 2023 +0800

    fix: type coercion for expr/subquery in InSubquery (#5883)
    
    * fix: type coercion for expr/subquery in InSubquery
    
    * fix review and add new UT
---
 datafusion/expr/src/expr_schema.rs                 |  30 ++--
 datafusion/optimizer/src/analyzer/type_coercion.rs | 185 +++++++++++----------
 2 files changed, 111 insertions(+), 104 deletions(-)

diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index ff6d9f5057..685d348ced 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -302,22 +302,13 @@ impl ExprSchemable for Expr {
         // like all of the binary expressions below. Perhaps Expr should track the
         // type of the expression?
 
-        // TODO(jackwener): Handle subqueries separately, need to refactor it.
-        match self {
-            Expr::ScalarSubquery(subquery) => {
-                return Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?));
-            }
-            Expr::Exists { subquery, negated } => {
-                return Ok(Expr::Exists {
-                    subquery: cast_subquery(subquery, cast_to_type)?,
-                    negated,
-                });
-            }
-            _ => {}
-        }
-
         if can_cast_types(&this_type, cast_to_type) {
-            Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone())))
+            match self {
+                Expr::ScalarSubquery(subquery) => {
+                    Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
+                }
+                _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))),
+            }
         } else {
             Err(DataFusionError::Plan(format!(
                 "Cannot automatically convert {this_type:?} to {cast_to_type:?}"
@@ -326,7 +317,12 @@ impl ExprSchemable for Expr {
     }
 }
 
-fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
+/// Cast the output of a subquery in InSubquery/ScalarSubquery to `cast_to_type` (no-op if its first field already has that type).
+pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
+    if subquery.subquery.schema().field(0).data_type() == cast_to_type {
+        return Ok(subquery);
+    }
+
     let plan = subquery.subquery.as_ref();
     let new_plan = match plan {
         LogicalPlan::Projection(projection) => {
@@ -343,7 +339,7 @@ fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery
                 .cast_to(cast_to_type, subquery.subquery.schema())?;
             LogicalPlan::Projection(Projection::try_new(
                 vec![cast_expr],
-                subquery.subquery.clone(),
+                subquery.subquery,
             )?)
         }
     };
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 0038aec933..952cc04b0b 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -25,6 +25,7 @@ use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter};
 use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue};
 use datafusion_expr::expr::{self, Between, BinaryExpr, Case, Like, WindowFunction};
+use datafusion_expr::expr_schema::cast_subquery;
 use datafusion_expr::logical_plan::Subquery;
 use datafusion_expr::type_coercion::binary::{
     coerce_types, comparison_coercion, like_coercion,
@@ -147,20 +148,21 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                 subquery,
                 negated,
             } => {
-                let expr_type = expr.get_type(&self.schema)?;
                 let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
+                let expr_type = expr.get_type(&self.schema)?;
                 let subquery_type = new_plan.schema().field(0).data_type();
-                let expr = if &expr_type == subquery_type {
-                    expr
-                } else {
-                    Box::new(expr.cast_to(subquery_type, &self.schema)?)
+                let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(DataFusionError::Plan(
+                    format!(
+                        "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery"
+                    ),
+                ))?;
+                let new_subquery = Subquery {
+                    subquery: Arc::new(new_plan),
+                    outer_ref_columns: subquery.outer_ref_columns,
                 };
                 Ok(Expr::InSubquery {
-                    expr,
-                    subquery: Subquery {
-                        subquery: Arc::new(new_plan),
-                        outer_ref_columns: subquery.outer_ref_columns,
-                    },
+                    expr: Box::new(expr.cast_to(&common_type, &self.schema)?),
+                    subquery: cast_subquery(new_subquery, &common_type)?,
                     negated,
                 })
             }
@@ -749,16 +751,30 @@ mod test {
     };
     use crate::test::assert_analyzed_plan_eq;
 
+    fn empty() -> Arc<LogicalPlan> {
+        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
+            produce_one_row: false,
+            schema: Arc::new(DFSchema::empty()),
+        }))
+    }
+
+    fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
+        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
+            produce_one_row: false,
+            schema: Arc::new(
+                DFSchema::new_with_metadata(
+                    vec![DFField::new_unqualified("a", data_type, true)],
+                    std::collections::HashMap::new(),
+                )
+                .unwrap(),
+            ),
+        }))
+    }
+
     #[test]
     fn simple_case() -> Result<()> {
         let expr = col("a").lt(lit(2_u32));
-        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
-            produce_one_row: false,
-            schema: Arc::new(DFSchema::new_with_metadata(
-                vec![DFField::new_unqualified("a", DataType::Float64, true)],
-                std::collections::HashMap::new(),
-            )?),
-        }));
+        let empty = empty_with_type(DataType::Float64);
         let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
         let expected = "Projection: a < CAST(UInt32(2) AS Float64)\n  EmptyRelation";
         assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)
@@ -767,13 +783,8 @@ mod test {
     #[test]
     fn nested_case() -> Result<()> {
         let expr = col("a").lt(lit(2_u32));
-        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
-            produce_one_row: false,
-            schema: Arc::new(DFSchema::new_with_metadata(
-                vec![DFField::new_unqualified("a", DataType::Float64, true)],
-                std::collections::HashMap::new(),
-            )?),
-        }));
+        let empty = empty_with_type(DataType::Float64);
+
         let plan = LogicalPlan::Projection(Projection::try_new(
             vec![expr.clone().or(expr)],
             empty,
@@ -952,10 +963,7 @@ mod test {
         //CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
         let expr = cast(lit("1998-03-18"), DataType::Date32)
             + lit(ScalarValue::IntervalDayTime(Some(386547056640)));
-        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
-            produce_one_row: false,
-            schema: Arc::new(DFSchema::empty()),
-        }));
+        let empty = empty();
         let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
         let expected =
             "Projection: CAST(Utf8(\"1998-03-18\") AS Date32) + IntervalDayTime(\"386547056640\")\n  EmptyRelation";
@@ -967,13 +975,7 @@ mod test {
     fn inlist_case() -> Result<()> {
         // a in (1,4,8), a is int64
         let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
-        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
-            produce_one_row: false,
-            schema: Arc::new(DFSchema::new_with_metadata(
-                vec![DFField::new_unqualified("a", DataType::Int64, true)],
-                std::collections::HashMap::new(),
-            )?),
-        }));
+        let empty = empty_with_type(DataType::Int64);
         let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
         let expected =
             "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\
@@ -1158,26 +1160,6 @@ mod test {
         Ok(())
     }
 
-    fn empty() -> Arc<LogicalPlan> {
-        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
-            produce_one_row: false,
-            schema: Arc::new(DFSchema::empty()),
-        }))
-    }
-
-    fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
-        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
-            produce_one_row: false,
-            schema: Arc::new(
-                DFSchema::new_with_metadata(
-                    vec![DFField::new_unqualified("a", data_type, true)],
-                    std::collections::HashMap::new(),
-                )
-                .unwrap(),
-            ),
-        }))
-    }
-
     #[test]
     fn test_type_coercion_rewrite() -> Result<()> {
         // gt
@@ -1223,10 +1205,7 @@ mod test {
             DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None),
         )
         .eq(cast(lit("1998-03-18"), DataType::Date32));
-        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
-            produce_one_row: false,
-            schema: Arc::new(DFSchema::empty()),
-        }));
+        let empty = empty();
         let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
         dbg!(&plan);
         let expected =
@@ -1392,10 +1371,7 @@ mod test {
                 DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None),
             )),
         ));
-        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
-            produce_one_row: false,
-            schema: Arc::new(DFSchema::empty()),
-        }));
+        let empty = empty();
         let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
         let expected = "Projection: IntervalYearMonth(\"12\") + CAST(Utf8(\"2000-01-01T00:00:00\") AS Timestamp(Nanosecond, None))\n  EmptyRelation";
         assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?;
@@ -1415,10 +1391,7 @@ mod test {
                 DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None),
             )),
         ));
-        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
-            produce_one_row: false,
-            schema: Arc::new(DFSchema::empty()),
-        }));
+        let empty = empty();
         let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
         dbg!(&plan);
         let expected =
@@ -1428,37 +1401,75 @@ mod test {
     }
 
     #[test]
-    fn in_subquery() -> Result<()> {
-        let empty_inside = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
-            produce_one_row: false,
-            schema: Arc::new(DFSchema::new_with_metadata(
-                vec![DFField::new_unqualified("a_int32", DataType::Int32, true)],
-                std::collections::HashMap::new(),
-            )?),
-        }));
-        let empty_outside = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
-            produce_one_row: false,
-            schema: Arc::new(DFSchema::new_with_metadata(
-                vec![DFField::new_unqualified("a_int64", DataType::Int64, true)],
-                std::collections::HashMap::new(),
-            )?),
-        }));
+    fn in_subquery_cast_subquery() -> Result<()> {
+        let empty_int32 = empty_with_type(DataType::Int32);
+        let empty_int64 = empty_with_type(DataType::Int64);
+
         let in_subquery_expr = Expr::InSubquery {
-            expr: Box::new(col("a_int64")),
+            expr: Box::new(col("a")),
             subquery: Subquery {
-                subquery: empty_inside,
+                subquery: empty_int32,
                 outer_ref_columns: vec![],
             },
             negated: false,
         };
-        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?);
-        // add cast for
+        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?);
+        // add cast for subquery
+        let expected = "\
+        Filter: a IN (<subquery>)\
+        \n  Subquery:\
+        \n    Projection: CAST(a AS Int64)\
+        \n      EmptyRelation\
+        \n  EmptyRelation";
+        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?;
+        Ok(())
+    }
+
+    #[test]
+    fn in_subquery_cast_expr() -> Result<()> {
+        let empty_int32 = empty_with_type(DataType::Int32);
+        let empty_int64 = empty_with_type(DataType::Int64);
+
+        let in_subquery_expr = Expr::InSubquery {
+            expr: Box::new(col("a")),
+            subquery: Subquery {
+                subquery: empty_int64,
+                outer_ref_columns: vec![],
+            },
+            negated: false,
+        };
+        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?);
+        // add cast for expr
         let expected = "\
-        Filter: CAST(a_int64 AS Int32) IN (<subquery>)\
+        Filter: CAST(a AS Int64) IN (<subquery>)\
         \n  Subquery:\
         \n    EmptyRelation\
         \n  EmptyRelation";
         assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?;
         Ok(())
     }
+
+    #[test]
+    fn in_subquery_cast_all() -> Result<()> {
+        let empty_inside = empty_with_type(DataType::Decimal128(10, 5));
+        let empty_outside = empty_with_type(DataType::Decimal128(8, 8));
+
+        let in_subquery_expr = Expr::InSubquery {
+            expr: Box::new(col("a")),
+            subquery: Subquery {
+                subquery: empty_inside,
+                outer_ref_columns: vec![],
+            },
+            negated: false,
+        };
+        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?);
+        // add cast for both expr and subquery
+        let expected = "Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)\
+        \n  Subquery:\
+        \n    Projection: CAST(a AS Decimal128(13, 8))\
+        \n      EmptyRelation\
+        \n  EmptyRelation";
+        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?;
+        Ok(())
+    }
 }