You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/06/09 10:41:23 UTC
[arrow-datafusion] branch master updated: Make sure that the data types are supported in hashjoin before genera… (#2702)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 67d91a7f1 Make sure that the data types are supported in hashjoin before genera… (#2702)
67d91a7f1 is described below
commit 67d91a7f1f26a795966c4cc0b200187778ee840c
Author: AssHero <hu...@gmail.com>
AuthorDate: Thu Jun 9 18:41:18 2022 +0800
Make sure that the data types are supported in hashjoin before genera… (#2702)
* Make sure that the join key data types are supported by hash join before generating a hash join logical plan.
If the data types are not supported by hash join, fall back to a cross join with filters.
* format the expected output of join test case
* refine the code
* refine the code according to code review's comments
* add comments
---
datafusion/core/src/physical_plan/hash_join.rs | 2 +
datafusion/core/tests/sql/joins.rs | 138 +++++++++++++++++++++++++
datafusion/expr/src/logical_plan/builder.rs | 54 +++++++---
datafusion/expr/src/utils.rs | 30 ++++++
datafusion/sql/src/planner.rs | 57 ++++++++--
5 files changed, 263 insertions(+), 18 deletions(-)
diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs
index 155d1c740..ee58206a7 100644
--- a/datafusion/core/src/physical_plan/hash_join.rs
+++ b/datafusion/core/src/physical_plan/hash_join.rs
@@ -947,6 +947,8 @@ macro_rules! equal_rows_elem {
}
/// Left and right row have equal values
+/// If more data types are supported here, please also add them to the `can_hash` function
+/// so that a hash join logical plan can be generated for them.
fn equal_rows(
left: usize,
right: usize,
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index df853b7c9..6b3b8c339 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1204,3 +1204,141 @@ async fn join_partitioned() -> Result<()> {
Ok(())
}
+
+#[tokio::test]
+async fn join_with_hash_unsupported_data_type() -> Result<()> {
+ let ctx = SessionContext::new();
+
+ let schema = Schema::new(vec![
+ Field::new("c1", DataType::Int32, true),
+ Field::new("c2", DataType::Utf8, true),
+ Field::new("c3", DataType::Int64, true),
+ Field::new("c4", DataType::Date32, true),
+ ]);
+ let data = RecordBatch::try_new(
+ Arc::new(schema),
+ vec![
+ Arc::new(Int32Array::from_slice(&[1, 2, 3])),
+ Arc::new(StringArray::from_slice(&["aaa", "bbb", "ccc"])),
+ Arc::new(Int64Array::from_slice(&[100, 200, 300])),
+ Arc::new(Date32Array::from(vec![Some(1), Some(2), Some(3)])),
+ ],
+ )?;
+ let table = MemTable::try_new(data.schema(), vec![vec![data]])?;
+ ctx.register_table("foo", Arc::new(table))?;
+
+ // join on a data type that hash join does not support (Date32): a cross join is used instead of a hash join
+ let sql = "select * from foo t1 join foo t2 on t1.c4 = t2.c4";
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " Filter: #t1.c4 = #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " CrossJoin: [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let expected = vec![
+ "+----+-----+-----+------------+----+-----+-----+------------+",
+ "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |",
+ "+----+-----+-----+------------+----+-----+-----+------------+",
+ "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |",
+ "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |",
+ "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |",
+ "+----+-----+-----+------------+----+-----+-----+------------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ // join on hash supported data type (Int32), use hash join
+ let sql = "select * from foo t1 join foo t2 on t1.c1 = t2.c1";
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let plan = state.optimize(&plan)?;
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " Inner Join: #t1.c1 = #t2.c1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let expected = vec![
+ "+----+-----+-----+------------+----+-----+-----+------------+",
+ "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |",
+ "+----+-----+-----+------------+----+-----+-----+------------+",
+ "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |",
+ "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |",
+ "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |",
+ "+----+-----+-----+------------+----+-----+-----+------------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ // join on two columns: one with a hash-supported data type (Int64) and one with a hash-unsupported data type (Date32);
+ // a hash join is used on the Int64 column and a filter is applied on the Date32 column.
+ let sql = "select * from foo t1, foo t2 where t1.c3 = t2.c3 and t1.c4 = t2.c4";
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let plan = state.optimize(&plan)?;
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " Filter: #t1.c4 = #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " Inner Join: #t1.c3 = #t2.c3 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let expected = vec![
+ "+----+-----+-----+------------+----+-----+-----+------------+",
+ "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |",
+ "+----+-----+-----+------------+----+-----+-----+------------+",
+ "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |",
+ "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |",
+ "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |",
+ "+----+-----+-----+------------+----+-----+-----+------------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index d0309b95f..509b47909 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -19,6 +19,7 @@
use crate::expr_rewriter::{normalize_col, normalize_cols, rewrite_sort_cols_by_aggs};
use crate::utils::{columnize_expr, exprlist_to_fields, from_plan};
+use crate::{and, binary_expr, Operator};
use crate::{
logical_plan::{
Aggregate, Analyze, CrossJoin, EmptyRelation, Explain, Filter, Join,
@@ -27,7 +28,7 @@ use crate::{
Window,
},
utils::{
- expand_qualified_wildcard, expand_wildcard, expr_to_columns,
+ can_hash, expand_qualified_wildcard, expand_wildcard, expr_to_columns,
group_window_expr_by_sort_keys,
},
Expr, ExprSchemable, TableSource,
@@ -603,17 +604,46 @@ impl LogicalPlanBuilder {
let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect();
let join_schema =
build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
-
- Ok(Self::from(LogicalPlan::Join(Join {
- left: Arc::new(self.plan.clone()),
- right: Arc::new(right.clone()),
- on,
- filter: None,
- join_type,
- join_constraint: JoinConstraint::Using,
- schema: DFSchemaRef::new(join_schema),
- null_equals_null: false,
- })))
+ let mut join_on: Vec<(Column, Column)> = vec![];
+ let mut filters: Option<Expr> = None;
+ for (l, r) in &on {
+ if self.plan.schema().field_from_column(l).is_ok()
+ && right.schema().field_from_column(r).is_ok()
+ && can_hash(self.plan.schema().field_from_column(l).unwrap().data_type())
+ {
+ join_on.push((l.clone(), r.clone()));
+ } else if self.plan.schema().field_from_column(r).is_ok()
+ && right.schema().field_from_column(l).is_ok()
+ && can_hash(self.plan.schema().field_from_column(r).unwrap().data_type())
+ {
+ join_on.push((r.clone(), l.clone()));
+ } else {
+ let expr = binary_expr(
+ Expr::Column(l.clone()),
+ Operator::Eq,
+ Expr::Column(r.clone()),
+ );
+ match filters {
+ None => filters = Some(expr),
+ Some(filter_expr) => filters = Some(and(expr, filter_expr)),
+ }
+ }
+ }
+ if join_on.is_empty() {
+ let join = Self::from(self.plan.clone()).cross_join(&right.clone())?;
+ join.filter(filters.unwrap())
+ } else {
+ Ok(Self::from(LogicalPlan::Join(Join {
+ left: Arc::new(self.plan.clone()),
+ right: Arc::new(right.clone()),
+ on: join_on,
+ filter: filters,
+ join_type,
+ join_constraint: JoinConstraint::Using,
+ schema: DFSchemaRef::new(join_schema),
+ null_equals_null: false,
+ })))
+ }
}
/// Apply a cross join
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 065e6120b..2120acaed 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -25,6 +25,7 @@ use crate::logical_plan::{
Window,
};
use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
+use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::{
Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
@@ -640,6 +641,35 @@ pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
}
}
+/// Returns whether this data type can be used in the equality conditions of a hash join.
+/// The data types listed here come from the function `equal_rows`; if more data types are supported
+/// in `equal_rows` (hash join), add them here so that a hash join logical plan can be generated for them.
+pub fn can_hash(data_type: &DataType) -> bool {
+ match data_type {
+ DataType::Null => true,
+ DataType::Boolean => true,
+ DataType::Int8 => true,
+ DataType::Int16 => true,
+ DataType::Int32 => true,
+ DataType::Int64 => true,
+ DataType::UInt8 => true,
+ DataType::UInt16 => true,
+ DataType::UInt32 => true,
+ DataType::UInt64 => true,
+ DataType::Float32 => true,
+ DataType::Float64 => true,
+ DataType::Timestamp(time_unit, None) => match time_unit {
+ TimeUnit::Second => true,
+ TimeUnit::Millisecond => true,
+ TimeUnit::Microsecond => true,
+ TimeUnit::Nanosecond => true,
+ },
+ DataType::Utf8 => true,
+ DataType::LargeUtf8 => true,
+ _ => false,
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 2a33be32d..5978b77ee 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -29,8 +29,9 @@ use datafusion_expr::logical_plan::{
ToStringifiedPlan,
};
use datafusion_expr::utils::{
- expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, expr_to_columns,
- find_aggregate_exprs, find_column_exprs, find_window_exprs, COUNT_STAR_EXPANSION,
+ can_hash, expand_qualified_wildcard, expand_wildcard, expr_as_column_expr,
+ expr_to_columns, find_aggregate_exprs, find_column_exprs, find_window_exprs,
+ COUNT_STAR_EXPANSION,
};
use datafusion_expr::{
and, col, lit, AggregateFunction, AggregateUDF, Expr, Operator, ScalarUDF,
@@ -594,7 +595,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let mut filter = vec![];
// extract join keys
- extract_join_keys(expr, &mut keys, &mut filter);
+ extract_join_keys(
+ expr,
+ &mut keys,
+ &mut filter,
+ left.schema(),
+ right.schema(),
+ );
let (left_keys, right_keys): (Vec<Column>, Vec<Column>) =
keys.into_iter().unzip();
@@ -813,10 +820,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
for (l, r) in &possible_join_keys {
if left_schema.field_from_column(l).is_ok()
&& right_schema.field_from_column(r).is_ok()
+ && can_hash(
+ left_schema
+ .field_from_column(l)
+ .unwrap()
+ .data_type(),
+ )
{
join_keys.push((l.clone(), r.clone()));
} else if left_schema.field_from_column(r).is_ok()
&& right_schema.field_from_column(l).is_ok()
+ && can_hash(
+ left_schema
+ .field_from_column(r)
+ .unwrap()
+ .data_type(),
+ )
{
join_keys.push((r.clone(), l.clone()));
}
@@ -2512,12 +2531,26 @@ fn extract_join_keys(
expr: Expr,
accum: &mut Vec<(Column, Column)>,
accum_filter: &mut Vec<Expr>,
+ left_schema: &Arc<DFSchema>,
+ right_schema: &Arc<DFSchema>,
) {
match &expr {
Expr::BinaryExpr { left, op, right } => match op {
Operator::Eq => match (left.as_ref(), right.as_ref()) {
(Expr::Column(l), Expr::Column(r)) => {
- accum.push((l.clone(), r.clone()));
+ if left_schema.field_from_column(l).is_ok()
+ && right_schema.field_from_column(r).is_ok()
+ && can_hash(left_schema.field_from_column(l).unwrap().data_type())
+ {
+ accum.push((l.clone(), r.clone()));
+ } else if left_schema.field_from_column(r).is_ok()
+ && right_schema.field_from_column(l).is_ok()
+ && can_hash(left_schema.field_from_column(r).unwrap().data_type())
+ {
+ accum.push((r.clone(), l.clone()));
+ } else {
+ accum_filter.push(expr);
+ }
}
_other => {
accum_filter.push(expr);
@@ -2525,8 +2558,20 @@ fn extract_join_keys(
},
Operator::And => {
if let Expr::BinaryExpr { left, op: _, right } = expr {
- extract_join_keys(*left, accum, accum_filter);
- extract_join_keys(*right, accum, accum_filter);
+ extract_join_keys(
+ *left,
+ accum,
+ accum_filter,
+ left_schema,
+ right_schema,
+ );
+ extract_join_keys(
+ *right,
+ accum,
+ accum_filter,
+ left_schema,
+ right_schema,
+ );
}
}
_other => {