You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ja...@apache.org on 2023/04/08 10:19:11 UTC
[arrow-datafusion] branch main updated: fix: type coercion for expr/subquery in InSubquery (#5883)
This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new bc0af737bc fix: type coercion for expr/subquery in InSubquery (#5883)
bc0af737bc is described below
commit bc0af737bce3c4283aad2b94752dbf5125754bf1
Author: jakevin <ja...@gmail.com>
AuthorDate: Sat Apr 8 18:19:03 2023 +0800
fix: type coercion for expr/subquery in InSubquery (#5883)
* fix: type coercion for expr/subquery in InSubquery
* address review comments and add new unit tests
---
datafusion/expr/src/expr_schema.rs | 30 ++--
datafusion/optimizer/src/analyzer/type_coercion.rs | 185 +++++++++++----------
2 files changed, 111 insertions(+), 104 deletions(-)
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index ff6d9f5057..685d348ced 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -302,22 +302,13 @@ impl ExprSchemable for Expr {
// like all of the binary expressions below. Perhaps Expr should track the
// type of the expression?
- // TODO(jackwener): Handle subqueries separately, need to refactor it.
- match self {
- Expr::ScalarSubquery(subquery) => {
- return Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?));
- }
- Expr::Exists { subquery, negated } => {
- return Ok(Expr::Exists {
- subquery: cast_subquery(subquery, cast_to_type)?,
- negated,
- });
- }
- _ => {}
- }
-
if can_cast_types(&this_type, cast_to_type) {
- Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone())))
+ match self {
+ Expr::ScalarSubquery(subquery) => {
+ Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
+ }
+ _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))),
+ }
} else {
Err(DataFusionError::Plan(format!(
"Cannot automatically convert {this_type:?} to {cast_to_type:?}"
@@ -326,7 +317,12 @@ impl ExprSchemable for Expr {
}
}
-fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
+/// cast subquery in InSubquery/ScalarSubquery to a given type.
+pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
+ if subquery.subquery.schema().field(0).data_type() == cast_to_type {
+ return Ok(subquery);
+ }
+
let plan = subquery.subquery.as_ref();
let new_plan = match plan {
LogicalPlan::Projection(projection) => {
@@ -343,7 +339,7 @@ fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery
.cast_to(cast_to_type, subquery.subquery.schema())?;
LogicalPlan::Projection(Projection::try_new(
vec![cast_expr],
- subquery.subquery.clone(),
+ subquery.subquery,
)?)
}
};
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 0038aec933..952cc04b0b 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -25,6 +25,7 @@ use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{self, Between, BinaryExpr, Case, Like, WindowFunction};
+use datafusion_expr::expr_schema::cast_subquery;
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::binary::{
coerce_types, comparison_coercion, like_coercion,
@@ -147,20 +148,21 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
subquery,
negated,
} => {
- let expr_type = expr.get_type(&self.schema)?;
let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
+ let expr_type = expr.get_type(&self.schema)?;
let subquery_type = new_plan.schema().field(0).data_type();
- let expr = if &expr_type == subquery_type {
- expr
- } else {
- Box::new(expr.cast_to(subquery_type, &self.schema)?)
+ let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(DataFusionError::Plan(
+ format!(
+ "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery"
+ ),
+ ))?;
+ let new_subquery = Subquery {
+ subquery: Arc::new(new_plan),
+ outer_ref_columns: subquery.outer_ref_columns,
};
Ok(Expr::InSubquery {
- expr,
- subquery: Subquery {
- subquery: Arc::new(new_plan),
- outer_ref_columns: subquery.outer_ref_columns,
- },
+ expr: Box::new(expr.cast_to(&common_type, &self.schema)?),
+ subquery: cast_subquery(new_subquery, &common_type)?,
negated,
})
}
@@ -749,16 +751,30 @@ mod test {
};
use crate::test::assert_analyzed_plan_eq;
+ fn empty() -> Arc<LogicalPlan> {
+ Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
+ produce_one_row: false,
+ schema: Arc::new(DFSchema::empty()),
+ }))
+ }
+
+ fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
+ Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
+ produce_one_row: false,
+ schema: Arc::new(
+ DFSchema::new_with_metadata(
+ vec![DFField::new_unqualified("a", data_type, true)],
+ std::collections::HashMap::new(),
+ )
+ .unwrap(),
+ ),
+ }))
+ }
+
#[test]
fn simple_case() -> Result<()> {
let expr = col("a").lt(lit(2_u32));
- let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
- produce_one_row: false,
- schema: Arc::new(DFSchema::new_with_metadata(
- vec![DFField::new_unqualified("a", DataType::Float64, true)],
- std::collections::HashMap::new(),
- )?),
- }));
+ let empty = empty_with_type(DataType::Float64);
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
let expected = "Projection: a < CAST(UInt32(2) AS Float64)\n EmptyRelation";
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)
@@ -767,13 +783,8 @@ mod test {
#[test]
fn nested_case() -> Result<()> {
let expr = col("a").lt(lit(2_u32));
- let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
- produce_one_row: false,
- schema: Arc::new(DFSchema::new_with_metadata(
- vec![DFField::new_unqualified("a", DataType::Float64, true)],
- std::collections::HashMap::new(),
- )?),
- }));
+ let empty = empty_with_type(DataType::Float64);
+
let plan = LogicalPlan::Projection(Projection::try_new(
vec![expr.clone().or(expr)],
empty,
@@ -952,10 +963,7 @@ mod test {
//CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
let expr = cast(lit("1998-03-18"), DataType::Date32)
+ lit(ScalarValue::IntervalDayTime(Some(386547056640)));
- let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
- produce_one_row: false,
- schema: Arc::new(DFSchema::empty()),
- }));
+ let empty = empty();
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
let expected =
"Projection: CAST(Utf8(\"1998-03-18\") AS Date32) + IntervalDayTime(\"386547056640\")\n EmptyRelation";
@@ -967,13 +975,7 @@ mod test {
fn inlist_case() -> Result<()> {
// a in (1,4,8), a is int64
let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
- let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
- produce_one_row: false,
- schema: Arc::new(DFSchema::new_with_metadata(
- vec![DFField::new_unqualified("a", DataType::Int64, true)],
- std::collections::HashMap::new(),
- )?),
- }));
+ let empty = empty_with_type(DataType::Int64);
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
let expected =
"Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\
@@ -1158,26 +1160,6 @@ mod test {
Ok(())
}
- fn empty() -> Arc<LogicalPlan> {
- Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
- produce_one_row: false,
- schema: Arc::new(DFSchema::empty()),
- }))
- }
-
- fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
- Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
- produce_one_row: false,
- schema: Arc::new(
- DFSchema::new_with_metadata(
- vec![DFField::new_unqualified("a", data_type, true)],
- std::collections::HashMap::new(),
- )
- .unwrap(),
- ),
- }))
- }
-
#[test]
fn test_type_coercion_rewrite() -> Result<()> {
// gt
@@ -1223,10 +1205,7 @@ mod test {
DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None),
)
.eq(cast(lit("1998-03-18"), DataType::Date32));
- let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
- produce_one_row: false,
- schema: Arc::new(DFSchema::empty()),
- }));
+ let empty = empty();
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
dbg!(&plan);
let expected =
@@ -1392,10 +1371,7 @@ mod test {
DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None),
)),
));
- let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
- produce_one_row: false,
- schema: Arc::new(DFSchema::empty()),
- }));
+ let empty = empty();
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
let expected = "Projection: IntervalYearMonth(\"12\") + CAST(Utf8(\"2000-01-01T00:00:00\") AS Timestamp(Nanosecond, None))\n EmptyRelation";
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?;
@@ -1415,10 +1391,7 @@ mod test {
DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None),
)),
));
- let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
- produce_one_row: false,
- schema: Arc::new(DFSchema::empty()),
- }));
+ let empty = empty();
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
dbg!(&plan);
let expected =
@@ -1428,37 +1401,75 @@ mod test {
}
#[test]
- fn in_subquery() -> Result<()> {
- let empty_inside = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
- produce_one_row: false,
- schema: Arc::new(DFSchema::new_with_metadata(
- vec![DFField::new_unqualified("a_int32", DataType::Int32, true)],
- std::collections::HashMap::new(),
- )?),
- }));
- let empty_outside = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
- produce_one_row: false,
- schema: Arc::new(DFSchema::new_with_metadata(
- vec![DFField::new_unqualified("a_int64", DataType::Int64, true)],
- std::collections::HashMap::new(),
- )?),
- }));
+ fn in_subquery_cast_subquery() -> Result<()> {
+ let empty_int32 = empty_with_type(DataType::Int32);
+ let empty_int64 = empty_with_type(DataType::Int64);
+
let in_subquery_expr = Expr::InSubquery {
- expr: Box::new(col("a_int64")),
+ expr: Box::new(col("a")),
subquery: Subquery {
- subquery: empty_inside,
+ subquery: empty_int32,
outer_ref_columns: vec![],
},
negated: false,
};
- let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?);
- // add cast for
+ let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?);
+ // add cast for subquery
+ let expected = "\
+ Filter: a IN (<subquery>)\
+ \n Subquery:\
+ \n Projection: CAST(a AS Int64)\
+ \n EmptyRelation\
+ \n EmptyRelation";
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?;
+ Ok(())
+ }
+
+ #[test]
+ fn in_subquery_cast_expr() -> Result<()> {
+ let empty_int32 = empty_with_type(DataType::Int32);
+ let empty_int64 = empty_with_type(DataType::Int64);
+
+ let in_subquery_expr = Expr::InSubquery {
+ expr: Box::new(col("a")),
+ subquery: Subquery {
+ subquery: empty_int64,
+ outer_ref_columns: vec![],
+ },
+ negated: false,
+ };
+ let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?);
+ // add cast for expr (outer column widened to the subquery's type)
let expected = "\
- Filter: CAST(a_int64 AS Int32) IN (<subquery>)\
+ Filter: CAST(a AS Int64) IN (<subquery>)\
\n Subquery:\
\n EmptyRelation\
\n EmptyRelation";
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?;
Ok(())
}
+
+ #[test]
+ fn in_subquery_cast_all() -> Result<()> {
+ let empty_inside = empty_with_type(DataType::Decimal128(10, 5));
+ let empty_outside = empty_with_type(DataType::Decimal128(8, 8));
+
+ let in_subquery_expr = Expr::InSubquery {
+ expr: Box::new(col("a")),
+ subquery: Subquery {
+ subquery: empty_inside,
+ outer_ref_columns: vec![],
+ },
+ negated: false,
+ };
+ let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?);
+ // add cast for both expr and subquery (common decimal type)
+ let expected = "Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)\
+ \n Subquery:\
+ \n Projection: CAST(a AS Decimal128(13, 8))\
+ \n EmptyRelation\
+ \n EmptyRelation";
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?;
+ Ok(())
+ }
}