You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/20 22:15:27 UTC

[arrow-datafusion] branch master updated: Support type coercion for equijoin (#4666)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new ac2e5d15e Support type coercion for equijoin (#4666)
ac2e5d15e is described below

commit ac2e5d15e5452e83c835d793a95335e87bf35569
Author: ygf11 <ya...@gmail.com>
AuthorDate: Wed Dec 21 06:15:21 2022 +0800

    Support type coercion for equijoin (#4666)
    
    * Support type coercion for equijoin
    
    * fix cargo fmt
    
    * add check for length of join expressions
    
    * Update test for logical conflict
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 datafusion/core/tests/sql/joins.rs       | 172 ++++++++++++++++++++++++++++---
 datafusion/expr/src/logical_plan/plan.rs |   5 +-
 datafusion/expr/src/utils.rs             |  30 ++++--
 3 files changed, 183 insertions(+), 24 deletions(-)

diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index 70b781399..9d7ddc526 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1448,11 +1448,11 @@ async fn hash_join_with_decimal() -> Result<()> {
     let state = ctx.state();
     let plan = state.optimize(&plan)?;
     let expected = vec![
-    "Explain [plan_type:Utf8, plan:Utf8]",
-    "  Projection: t1.c1, t1.c2, t1.c3, t1.c4, t2.c1, t2.c2, t2.c3, t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
-    "    Right Join: t1.c3 = t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
-    "      TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]",
-    "      TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: t1.c1, t1.c2, t1.c3, t1.c4, t2.c1, t2.c2, t2.c3, t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+        "    Right Join: CAST(t1.c3 AS Decimal128(10, 2)) = t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+        "      TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]",
+        "      TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
     ];
     let formatted = plan.display_indent_schema().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -1982,19 +1982,22 @@ async fn sort_merge_join_on_decimal() -> Result<()> {
     let state = ctx.state();
     let logical_plan = state.optimize(&plan)?;
     let physical_plan = state.create_physical_plan(&logical_plan).await?;
+
     let expected = vec![
         "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@4 as c1, c2@5 as c2, c3@6 as c3, c4@7 as c4]",
-        "  SortMergeJoin: join_type=Right, on=[(Column { name: \"c3\", index: 2 }, Column { name: \"c3\", index: 2 })]",
-        "    SortExec: [c3@2 ASC]",
-        "      CoalesceBatchesExec: target_batch_size=4096",
-        "        RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)",
-        "          RepartitionExec: partitioning=RoundRobinBatch(2)",
-        "            MemoryExec: partitions=1, partition_sizes=[1]",
-        "    SortExec: [c3@2 ASC]",
-        "      CoalesceBatchesExec: target_batch_size=4096",
-        "        RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)",
-        "          RepartitionExec: partitioning=RoundRobinBatch(2)",
-        "            MemoryExec: partitions=1, partition_sizes=[1]",
+        "  ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@5 as c1, c2@6 as c2, c3@7 as c3, c4@8 as c4]",
+        "    SortMergeJoin: join_type=Right, on=[(Column { name: \"CAST(t1.c3 AS Decimal128(10, 2))\", index: 4 }, Column { name: \"c3\", index: 2 })]",
+        "      SortExec: [CAST(t1.c3 AS Decimal128(10, 2))@4 ASC]",
+        "        CoalesceBatchesExec: target_batch_size=4096",
+        "          RepartitionExec: partitioning=Hash([Column { name: \"CAST(t1.c3 AS Decimal128(10, 2))\", index: 4 }], 2)",
+        "            ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, CAST(c3@2 AS Decimal128(10, 2)) as CAST(t1.c3 AS Decimal128(10, 2))]",
+        "              RepartitionExec: partitioning=RoundRobinBatch(2)",
+        "                MemoryExec: partitions=1, partition_sizes=[1]",
+        "      SortExec: [c3@2 ASC]",
+        "        CoalesceBatchesExec: target_batch_size=4096",
+        "          RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)",
+        "            RepartitionExec: partitioning=RoundRobinBatch(2)",
+        "              MemoryExec: partitions=1, partition_sizes=[1]",
     ];
     let formatted = displayable(physical_plan.as_ref()).indent().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -2776,3 +2779,140 @@ async fn select_wildcard_with_expr_key_inner_join() -> Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn join_with_type_coercion_for_equi_expr() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+    let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id + 11 = t2.t2_id";
+
+    // assert the optimized logical plan: the UInt32 join keys are coerced to Int64 on both sides of the equijoin
+    let msg = format!("Creating logical plan for '{}'", sql);
+    let plan = ctx
+        .create_logical_plan(&("explain ".to_owned() + sql))
+        .expect(&msg);
+    let state = ctx.state();
+    let plan = state.optimize(&plan)?;
+
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+        "    Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+        "      TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
+        "      TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+    ];
+
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+
+    let expected = vec![
+        "+-------+---------+-------+",
+        "| t1_id | t1_name | t2_id |",
+        "+-------+---------+-------+",
+        "| 11    | a       | 22    |",
+        "| 33    | c       | 44    |",
+        "| 44    | d       | 55    |",
+        "+-------+---------+-------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await; // execute the query itself (without EXPLAIN)
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn join_only_with_filter() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+    let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id * 4 < t2.t2_id";
+
+    // assert the optimized logical plan: with no equality predicate the ON clause becomes only a join filter, whose operands are still coerced to Int64
+    let msg = format!("Creating logical plan for '{}'", sql);
+    let plan = ctx
+        .create_logical_plan(&("explain ".to_owned() + sql))
+        .expect(&msg);
+    let state = ctx.state();
+    let plan = state.optimize(&plan)?;
+
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+        "    Inner Join:  Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+        "      TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
+        "      TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+    ];
+
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+
+    let expected = vec![
+        "+-------+---------+-------+",
+        "| t1_id | t1_name | t2_id |",
+        "+-------+---------+-------+",
+        "| 11    | a       | 55    |",
+        "+-------+---------+-------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await; // execute the query itself (without EXPLAIN)
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+    let sql = "select t1.t1_id, t1.t1_name, t2.t2_id \
+                     from t1 \
+                     inner join t2 \
+                     on t1.t1_id * 5 = t2.t2_id and t1.t1_id * 4 < t2.t2_id";
+
+    // assert the optimized logical plan: `t1_id * 5 = t2_id` becomes the equijoin key and `t1_id * 4 < t2_id` the join filter, both coerced to Int64
+    let msg = format!("Creating logical plan for '{}'", sql);
+    let plan = ctx
+        .create_logical_plan(&("explain ".to_owned() + sql))
+        .expect(&msg);
+    let state = ctx.state();
+    let plan = state.optimize(&plan)?;
+
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+        "    Inner Join: CAST(t1.t1_id AS Int64) * Int64(5) = CAST(t2.t2_id AS Int64) Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+        "      TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
+        "      TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+    ];
+
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+
+    let expected = vec![
+        "+-------+---------+-------+",
+        "| t1_id | t1_name | t2_id |",
+        "+-------+---------+-------+",
+        "| 11    | a       | 55    |",
+        "+-------+---------+-------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await; // execute the query itself (without EXPLAIN)
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 9d7fdf8f0..9b12287f4 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -253,9 +253,12 @@ impl LogicalPlan {
                 aggr_expr,
                 ..
             }) => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(),
+            // A join's expressions come in two parts: equijoin (`on`) and non-equijoin (`filter`).
+            // 1. The first `on.len()` expressions are the equijoin conditions, each of the form `left-on = right-on`.
+            // 2. The remaining part is the non-equijoin filter, if present.
             LogicalPlan::Join(Join { on, filter, .. }) => on
                 .iter()
-                .flat_map(|(l, r)| vec![l.clone(), r.clone()])
+                .map(|(l, r)| Expr::eq(l.clone(), r.clone()))
                 .chain(
                     filter
                         .as_ref()
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 89229a3d4..3ee36de17 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -27,7 +27,8 @@ use crate::logical_plan::{
     SubqueryAlias, Union, Values, Window,
 };
 use crate::{
-    Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, TableScan, TryCast,
+    BinaryExpr, Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, Operator,
+    TableScan, TryCast,
 };
 use arrow::datatypes::{DataType, TimeUnit};
 use datafusion_common::{
@@ -567,20 +568,35 @@ pub fn from_plan(
         }) => {
             let schema =
                 build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?;
+
+            let equi_expr_count = on.len();
+            assert!(expr.len() >= equi_expr_count);
+
+            // The first `on.len()` entries of `expr` are the equijoin expressions,
+            // each a binary `left-expr = right-expr` comparison.
+            let new_on:Vec<(Expr,Expr)> = expr.iter().take(equi_expr_count).map(|equi_expr| {
+                    if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = equi_expr {
+                        assert!(op == &Operator::Eq);
+                        Ok(((**left).clone(), (**right).clone()))
+                    } else {
+                        Err(DataFusionError::Internal(format!(
+                            "The front part expressions should be an binary expression, actual:{}",
+                            equi_expr
+                        )))
+                    }
+                }).collect::<Result<Vec<(Expr, Expr)>>>()?;
+
             // Assume that the last expr, if any,
             // is the filter_expr (non equality predicate from ON clause)
-            let filter_expr = if on.len() * 2 == expr.len() {
-                None
-            } else {
-                Some(expr[expr.len() - 1].clone())
-            };
+            let filter_expr =
+                (expr.len() > equi_expr_count).then(|| expr[expr.len() - 1].clone());
 
             Ok(LogicalPlan::Join(Join {
                 left: Arc::new(inputs[0].clone()),
                 right: Arc::new(inputs[1].clone()),
                 join_type: *join_type,
                 join_constraint: *join_constraint,
-                on: on.clone(),
+                on: new_on,
                 filter: filter_expr,
                 schema: DFSchemaRef::new(schema),
                 null_equals_null: *null_equals_null,