You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/20 22:15:27 UTC
[arrow-datafusion] branch master updated: Support type coercion for equijoin (#4666)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new ac2e5d15e Support type coercion for equijoin (#4666)
ac2e5d15e is described below
commit ac2e5d15e5452e83c835d793a95335e87bf35569
Author: ygf11 <ya...@gmail.com>
AuthorDate: Wed Dec 21 06:15:21 2022 +0800
Support type coercion for equijoin (#4666)
* Support type coercion for equijoin
* fix cargo fmt
* add check for length of join expressions
* Update test for logical conflict
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
datafusion/core/tests/sql/joins.rs | 172 ++++++++++++++++++++++++++++---
datafusion/expr/src/logical_plan/plan.rs | 5 +-
datafusion/expr/src/utils.rs | 30 ++++--
3 files changed, 183 insertions(+), 24 deletions(-)
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index 70b781399..9d7ddc526 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1448,11 +1448,11 @@ async fn hash_join_with_decimal() -> Result<()> {
let state = ctx.state();
let plan = state.optimize(&plan)?;
let expected = vec![
- "Explain [plan_type:Utf8, plan:Utf8]",
- " Projection: t1.c1, t1.c2, t1.c3, t1.c4, t2.c1, t2.c2, t2.c3, t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
- " Right Join: t1.c3 = t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
- " TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]",
- " TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.c1, t1.c2, t1.c3, t1.c4, t2.c1, t2.c2, t2.c3, t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " Right Join: CAST(t1.c3 AS Decimal128(10, 2)) = t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -1982,19 +1982,22 @@ async fn sort_merge_join_on_decimal() -> Result<()> {
let state = ctx.state();
let logical_plan = state.optimize(&plan)?;
let physical_plan = state.create_physical_plan(&logical_plan).await?;
+
let expected = vec![
"ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@4 as c1, c2@5 as c2, c3@6 as c3, c4@7 as c4]",
- " SortMergeJoin: join_type=Right, on=[(Column { name: \"c3\", index: 2 }, Column { name: \"c3\", index: 2 })]",
- " SortExec: [c3@2 ASC]",
- " CoalesceBatchesExec: target_batch_size=4096",
- " RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)",
- " RepartitionExec: partitioning=RoundRobinBatch(2)",
- " MemoryExec: partitions=1, partition_sizes=[1]",
- " SortExec: [c3@2 ASC]",
- " CoalesceBatchesExec: target_batch_size=4096",
- " RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)",
- " RepartitionExec: partitioning=RoundRobinBatch(2)",
- " MemoryExec: partitions=1, partition_sizes=[1]",
+ " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@5 as c1, c2@6 as c2, c3@7 as c3, c4@8 as c4]",
+ " SortMergeJoin: join_type=Right, on=[(Column { name: \"CAST(t1.c3 AS Decimal128(10, 2))\", index: 4 }, Column { name: \"c3\", index: 2 })]",
+ " SortExec: [CAST(t1.c3 AS Decimal128(10, 2))@4 ASC]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name: \"CAST(t1.c3 AS Decimal128(10, 2))\", index: 4 }], 2)",
+ " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, CAST(c3@2 AS Decimal128(10, 2)) as CAST(t1.c3 AS Decimal128(10, 2))]",
+ " RepartitionExec: partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ " SortExec: [c3@2 ASC]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)",
+ " RepartitionExec: partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
];
let formatted = displayable(physical_plan.as_ref()).indent().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -2776,3 +2779,140 @@ async fn select_wildcard_with_expr_key_inner_join() -> Result<()> {
Ok(())
}
+
+#[tokio::test]
+async fn join_with_type_coercion_for_equi_expr() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+ // Equi-join key `t1.t1_id + 11 = t2.t2_id`: the UInt32 columns should be cast to Int64.
+ let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id + 11 = t2.t2_id";
+
+ // Assert the optimized logical plan; planning `explain <sql>` adds the Explain root node.
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+ " Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ ];
+
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ // Assert the executed rows: every returned pair satisfies t1_id + 11 = t2_id.
+ let expected = vec![
+ "+-------+---------+-------+",
+ "| t1_id | t1_name | t2_id |",
+ "+-------+---------+-------+",
+ "| 11 | a | 22 |",
+ "| 33 | c | 44 |",
+ "| 44 | d | 55 |",
+ "+-------+---------+-------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn join_only_with_filter() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+ // No equality in the ON clause: expect no equi-join keys, only a type-coerced join filter.
+ let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id * 4 < t2.t2_id";
+
+ // Assert the optimized logical plan; planning `explain <sql>` adds the Explain root node.
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+ " Inner Join: Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ ];
+
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ // Assert the executed rows: every returned pair satisfies t1_id * 4 < t2_id.
+ let expected = vec![
+ "+-------+---------+-------+",
+ "| t1_id | t1_name | t2_id |",
+ "+-------+---------+-------+",
+ "| 11 | a | 55 |",
+ "+-------+---------+-------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+ // Equality -> equi-join key, inequality -> join filter; both are expected to be type-coerced.
+ let sql = "select t1.t1_id, t1.t1_name, t2.t2_id \
+ from t1 \
+ inner join t2 \
+ on t1.t1_id * 5 = t2.t2_id and t1.t1_id * 4 < t2.t2_id";
+
+ // Assert the optimized logical plan; planning `explain <sql>` adds the Explain root node.
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+ " Inner Join: CAST(t1.t1_id AS Int64) * Int64(5) = CAST(t2.t2_id AS Int64) Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ ];
+
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ // Assert the executed rows: every returned pair satisfies both ON predicates.
+ let expected = vec![
+ "+-------+---------+-------+",
+ "| t1_id | t1_name | t2_id |",
+ "+-------+---------+-------+",
+ "| 11 | a | 55 |",
+ "+-------+---------+-------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 9d7fdf8f0..9b12287f4 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -253,9 +253,12 @@ impl LogicalPlan {
aggr_expr,
..
}) => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(),
+ // A join carries two kinds of expressions: equijoin (`on`) and non-equijoin (`filter`).
+ // 1. The first `on.len()` expressions are the equijoin pairs, each as `left-on = right-on`.
+ // 2. The last expression, if present, is the non-equijoin `filter`.
LogicalPlan::Join(Join { on, filter, .. }) => on
.iter()
- .flat_map(|(l, r)| vec![l.clone(), r.clone()])
+ .map(|(l, r)| Expr::eq(l.clone(), r.clone()))
.chain(
filter
.as_ref()
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 89229a3d4..3ee36de17 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -27,7 +27,8 @@ use crate::logical_plan::{
SubqueryAlias, Union, Values, Window,
};
use crate::{
- Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, TableScan, TryCast,
+ BinaryExpr, Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, Operator,
+ TableScan, TryCast,
};
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::{
@@ -567,20 +568,35 @@ pub fn from_plan(
}) => {
let schema =
build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?;
+
+ let equi_expr_count = on.len();
+ assert!(expr.len() >= equi_expr_count);
+
+ // The first `equi_expr_count` entries of `expr` are the equijoin expressions,
+ // each of the form `left-expr = right-expr`.
+ let new_on:Vec<(Expr,Expr)> = expr.iter().take(equi_expr_count).map(|equi_expr| {
+ if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = equi_expr {
+ assert!(op == &Operator::Eq);
+ Ok(((**left).clone(), (**right).clone()))
+ } else {
+ Err(DataFusionError::Internal(format!(
+ "The front part expressions should be an binary expression, actual:{}",
+ equi_expr
+ )))
+ }
+ }).collect::<Result<Vec<(Expr, Expr)>>>()?;
+
// Assume that the last expr, if any,
// is the filter_expr (non equality predicate from ON clause)
- let filter_expr = if on.len() * 2 == expr.len() {
- None
- } else {
- Some(expr[expr.len() - 1].clone())
- };
+ let filter_expr =
+ (expr.len() > equi_expr_count).then(|| expr[expr.len() - 1].clone());
Ok(LogicalPlan::Join(Join {
left: Arc::new(inputs[0].clone()),
right: Arc::new(inputs[1].clone()),
join_type: *join_type,
join_constraint: *join_constraint,
- on: on.clone(),
+ on: new_on,
filter: filter_expr,
schema: DFSchemaRef::new(schema),
null_equals_null: *null_equals_null,