You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/06/09 10:41:23 UTC

[arrow-datafusion] branch master updated: Make sure that the data types are supported in hashjoin before genera… (#2702)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 67d91a7f1 Make sure that the data types are supported in hashjoin before genera… (#2702)
67d91a7f1 is described below

commit 67d91a7f1f26a795966c4cc0b200187778ee840c
Author: AssHero <hu...@gmail.com>
AuthorDate: Thu Jun 9 18:41:18 2022 +0800

    Make sure that the data types are supported in hashjoin before genera… (#2702)
    
    * Make sure that the data types are supported by hash join before generating a hash join logical plan.
    If the data types are not supported by hash join, fall back to a cross join with filters.
    
    * format the expected output of the join test case
    
    * refine the code
    
    * refine the code according to code review comments
    
    * add comments
---
 datafusion/core/src/physical_plan/hash_join.rs |   2 +
 datafusion/core/tests/sql/joins.rs             | 138 +++++++++++++++++++++++++
 datafusion/expr/src/logical_plan/builder.rs    |  54 +++++++---
 datafusion/expr/src/utils.rs                   |  30 ++++++
 datafusion/sql/src/planner.rs                  |  57 ++++++++--
 5 files changed, 263 insertions(+), 18 deletions(-)

diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs
index 155d1c740..ee58206a7 100644
--- a/datafusion/core/src/physical_plan/hash_join.rs
+++ b/datafusion/core/src/physical_plan/hash_join.rs
@@ -947,6 +947,8 @@ macro_rules! equal_rows_elem {
 }
 
 /// Left and right row have equal values
+/// If more data types are supported here, please also add them to the `can_hash` function
+/// so that a hash join logical plan is generated for them.
 fn equal_rows(
     left: usize,
     right: usize,
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index df853b7c9..6b3b8c339 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1204,3 +1204,141 @@ async fn join_partitioned() -> Result<()> {
 
     Ok(())
 }
+
+/// Checks join planning for hashable (Int32, Int64) and non-hashable (Date32) key types.
+#[tokio::test]
+async fn join_with_hash_unsupported_data_type() -> Result<()> {
+    let ctx = SessionContext::new();
+
+    let schema = Schema::new(vec![
+        Field::new("c1", DataType::Int32, true),
+        Field::new("c2", DataType::Utf8, true),
+        Field::new("c3", DataType::Int64, true),
+        Field::new("c4", DataType::Date32, true),
+    ]);
+    let data = RecordBatch::try_new(
+        Arc::new(schema),
+        vec![
+            Arc::new(Int32Array::from_slice(&[1, 2, 3])),
+            Arc::new(StringArray::from_slice(&["aaa", "bbb", "ccc"])),
+            Arc::new(Int64Array::from_slice(&[100, 200, 300])),
+            Arc::new(Date32Array::from(vec![Some(1), Some(2), Some(3)])),
+        ],
+    )?;
+    let table = MemTable::try_new(data.schema(), vec![vec![data]])?;
+    ctx.register_table("foo", Arc::new(table))?;
+
+    // join on a hash-unsupported data type (Date32): expect a cross join plus a filter
+    let sql = "select * from foo t1 join foo t2 on t1.c4 = t2.c4";
+    let plan = ctx
+        .create_logical_plan(&("explain ".to_owned() + sql))
+        .unwrap_or_else(|e| panic!("Creating logical plan for '{}': {}", sql, e));
+    let state = ctx.state();
+    let plan = state.optimize(&plan)?;
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "    Filter: #t1.c4 = #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "      CrossJoin: [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "        SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "          TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "        SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "          TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+    ];
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+
+    let expected = vec![
+        "+----+-----+-----+------------+----+-----+-----+------------+",
+        "| c1 | c2  | c3  | c4         | c1 | c2  | c3  | c4         |",
+        "+----+-----+-----+------------+----+-----+-----+------------+",
+        "| 1  | aaa | 100 | 1970-01-02 | 1  | aaa | 100 | 1970-01-02 |",
+        "| 2  | bbb | 200 | 1970-01-03 | 2  | bbb | 200 | 1970-01-03 |",
+        "| 3  | ccc | 300 | 1970-01-04 | 3  | ccc | 300 | 1970-01-04 |",
+        "+----+-----+-----+------------+----+-----+-----+------------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    // join on a hash-supported data type (Int32): expect an inner hash join
+    let sql = "select * from foo t1 join foo t2 on t1.c1 = t2.c1";
+    let plan = ctx
+        .create_logical_plan(&("explain ".to_owned() + sql))
+        .unwrap_or_else(|e| panic!("Creating logical plan for '{}': {}", sql, e));
+    let plan = state.optimize(&plan)?;
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "    Inner Join: #t1.c1 = #t2.c1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "      SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "        TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "      SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "        TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+    ];
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+
+    let expected = vec![
+        "+----+-----+-----+------------+----+-----+-----+------------+",
+        "| c1 | c2  | c3  | c4         | c1 | c2  | c3  | c4         |",
+        "+----+-----+-----+------------+----+-----+-----+------------+",
+        "| 1  | aaa | 100 | 1970-01-02 | 1  | aaa | 100 | 1970-01-02 |",
+        "| 2  | bbb | 200 | 1970-01-03 | 2  | bbb | 200 | 1970-01-03 |",
+        "| 3  | ccc | 300 | 1970-01-04 | 3  | ccc | 300 | 1970-01-04 |",
+        "+----+-----+-----+------------+----+-----+-----+------------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    // join on Int64 (hash-supported) and Date32 (hash-unsupported) columns:
+    // expect a hash join on the Int64 column and a filter on the Date32 column.
+    let sql = "select * from foo t1, foo t2 where t1.c3 = t2.c3 and t1.c4 = t2.c4";
+    let plan = ctx
+        .create_logical_plan(&("explain ".to_owned() + sql))
+        .unwrap_or_else(|e| panic!("Creating logical plan for '{}': {}", sql, e));
+    let plan = state.optimize(&plan)?;
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "    Filter: #t1.c4 = #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "      Inner Join: #t1.c3 = #t2.c3 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "        SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "          TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "        SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+        "          TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+    ];
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+
+    let expected = vec![
+        "+----+-----+-----+------------+----+-----+-----+------------+",
+        "| c1 | c2  | c3  | c4         | c1 | c2  | c3  | c4         |",
+        "+----+-----+-----+------------+----+-----+-----+------------+",
+        "| 1  | aaa | 100 | 1970-01-02 | 1  | aaa | 100 | 1970-01-02 |",
+        "| 2  | bbb | 200 | 1970-01-03 | 2  | bbb | 200 | 1970-01-03 |",
+        "| 3  | ccc | 300 | 1970-01-04 | 3  | ccc | 300 | 1970-01-04 |",
+        "+----+-----+-----+------------+----+-----+-----+------------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index d0309b95f..509b47909 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -19,6 +19,7 @@
 
 use crate::expr_rewriter::{normalize_col, normalize_cols, rewrite_sort_cols_by_aggs};
 use crate::utils::{columnize_expr, exprlist_to_fields, from_plan};
+use crate::{and, binary_expr, Operator};
 use crate::{
     logical_plan::{
         Aggregate, Analyze, CrossJoin, EmptyRelation, Explain, Filter, Join,
@@ -27,7 +28,7 @@ use crate::{
         Window,
     },
     utils::{
-        expand_qualified_wildcard, expand_wildcard, expr_to_columns,
+        can_hash, expand_qualified_wildcard, expand_wildcard, expr_to_columns,
         group_window_expr_by_sort_keys,
     },
     Expr, ExprSchemable, TableSource,
@@ -603,17 +604,46 @@ impl LogicalPlanBuilder {
         let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect();
         let join_schema =
             build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
-
-        Ok(Self::from(LogicalPlan::Join(Join {
-            left: Arc::new(self.plan.clone()),
-            right: Arc::new(right.clone()),
-            on,
-            filter: None,
-            join_type,
-            join_constraint: JoinConstraint::Using,
-            schema: DFSchemaRef::new(join_schema),
-            null_equals_null: false,
-        })))
+        let mut join_on: Vec<(Column, Column)> = vec![];
+        let mut filters: Option<Expr> = None;
+        for (l, r) in &on {
+            if self.plan.schema().field_from_column(l).is_ok()
+                && right.schema().field_from_column(r).is_ok()
+                && can_hash(self.plan.schema().field_from_column(l).unwrap().data_type())
+            {
+                join_on.push((l.clone(), r.clone()));
+            } else if self.plan.schema().field_from_column(r).is_ok()
+                && right.schema().field_from_column(l).is_ok()
+                && can_hash(self.plan.schema().field_from_column(r).unwrap().data_type())
+            {
+                join_on.push((r.clone(), l.clone()));
+            } else {
+                let expr = binary_expr(
+                    Expr::Column(l.clone()),
+                    Operator::Eq,
+                    Expr::Column(r.clone()),
+                );
+                match filters {
+                    None => filters = Some(expr),
+                    Some(filter_expr) => filters = Some(and(expr, filter_expr)),
+                }
+            }
+        }
+        if join_on.is_empty() {
+            let join = Self::from(self.plan.clone()).cross_join(&right.clone())?;
+            join.filter(filters.unwrap())
+        } else {
+            Ok(Self::from(LogicalPlan::Join(Join {
+                left: Arc::new(self.plan.clone()),
+                right: Arc::new(right.clone()),
+                on: join_on,
+                filter: filters,
+                join_type,
+                join_constraint: JoinConstraint::Using,
+                schema: DFSchemaRef::new(join_schema),
+                null_equals_null: false,
+            })))
+        }
     }
 
     /// Apply a cross join
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 065e6120b..2120acaed 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -25,6 +25,7 @@ use crate::logical_plan::{
     Window,
 };
 use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
+use arrow::datatypes::{DataType, TimeUnit};
 use datafusion_common::{
     Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
 };
@@ -640,6 +641,35 @@ pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
     }
 }
 
+/// Returns `true` if a column of `data_type` may be used as an equi-join key of a hash join.
+///
+/// The accepted types mirror those compared by `equal_rows` in hash_join.rs; keep the two in
+/// sync, because an equality on a type rejected here is planned as a filter instead of a key.
+pub fn can_hash(data_type: &DataType) -> bool {
+    matches!(
+        data_type,
+        DataType::Null
+            | DataType::Boolean
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Float32
+            | DataType::Float64
+            // only timestamps whose second (timezone) field is `None`
+            | DataType::Timestamp(TimeUnit::Second, None)
+            | DataType::Timestamp(TimeUnit::Millisecond, None)
+            | DataType::Timestamp(TimeUnit::Microsecond, None)
+            | DataType::Timestamp(TimeUnit::Nanosecond, None)
+            | DataType::Utf8
+            | DataType::LargeUtf8
+    )
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 2a33be32d..5978b77ee 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -29,8 +29,9 @@ use datafusion_expr::logical_plan::{
     ToStringifiedPlan,
 };
 use datafusion_expr::utils::{
-    expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, expr_to_columns,
-    find_aggregate_exprs, find_column_exprs, find_window_exprs, COUNT_STAR_EXPANSION,
+    can_hash, expand_qualified_wildcard, expand_wildcard, expr_as_column_expr,
+    expr_to_columns, find_aggregate_exprs, find_column_exprs, find_window_exprs,
+    COUNT_STAR_EXPANSION,
 };
 use datafusion_expr::{
     and, col, lit, AggregateFunction, AggregateUDF, Expr, Operator, ScalarUDF,
@@ -594,7 +595,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 let mut filter = vec![];
 
                 // extract join keys
-                extract_join_keys(expr, &mut keys, &mut filter);
+                extract_join_keys(
+                    expr,
+                    &mut keys,
+                    &mut filter,
+                    left.schema(),
+                    right.schema(),
+                );
 
                 let (left_keys, right_keys): (Vec<Column>, Vec<Column>) =
                     keys.into_iter().unzip();
@@ -813,10 +820,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                             for (l, r) in &possible_join_keys {
                                 if left_schema.field_from_column(l).is_ok()
                                     && right_schema.field_from_column(r).is_ok()
+                                    && can_hash(
+                                        left_schema
+                                            .field_from_column(l)
+                                            .unwrap()
+                                            .data_type(),
+                                    )
                                 {
                                     join_keys.push((l.clone(), r.clone()));
                                 } else if left_schema.field_from_column(r).is_ok()
                                     && right_schema.field_from_column(l).is_ok()
+                                    && can_hash(
+                                        left_schema
+                                            .field_from_column(r)
+                                            .unwrap()
+                                            .data_type(),
+                                    )
                                 {
                                     join_keys.push((r.clone(), l.clone()));
                                 }
@@ -2512,12 +2531,26 @@ fn extract_join_keys(
     expr: Expr,
     accum: &mut Vec<(Column, Column)>,
     accum_filter: &mut Vec<Expr>,
+    left_schema: &Arc<DFSchema>,
+    right_schema: &Arc<DFSchema>,
 ) {
     match &expr {
         Expr::BinaryExpr { left, op, right } => match op {
             Operator::Eq => match (left.as_ref(), right.as_ref()) {
                 (Expr::Column(l), Expr::Column(r)) => {
-                    accum.push((l.clone(), r.clone()));
+                    if left_schema.field_from_column(l).is_ok()
+                        && right_schema.field_from_column(r).is_ok()
+                        && can_hash(left_schema.field_from_column(l).unwrap().data_type())
+                    {
+                        accum.push((l.clone(), r.clone()));
+                    } else if left_schema.field_from_column(r).is_ok()
+                        && right_schema.field_from_column(l).is_ok()
+                        && can_hash(left_schema.field_from_column(r).unwrap().data_type())
+                    {
+                        accum.push((r.clone(), l.clone()));
+                    } else {
+                        accum_filter.push(expr);
+                    }
                 }
                 _other => {
                     accum_filter.push(expr);
@@ -2525,8 +2558,20 @@ fn extract_join_keys(
             },
             Operator::And => {
                 if let Expr::BinaryExpr { left, op: _, right } = expr {
-                    extract_join_keys(*left, accum, accum_filter);
-                    extract_join_keys(*right, accum, accum_filter);
+                    extract_join_keys(
+                        *left,
+                        accum,
+                        accum_filter,
+                        left_schema,
+                        right_schema,
+                    );
+                    extract_join_keys(
+                        *right,
+                        accum,
+                        accum_filter,
+                        left_schema,
+                        right_schema,
+                    );
                 }
             }
             _other => {