You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/06/22 13:31:17 UTC
[arrow-datafusion] branch master updated: Add support for additional data types in hash join (#2721)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 80e6e7520 Add support for additional data types in hash join (#2721)
80e6e7520 is described below
commit 80e6e75200bfd58a9b8fdaf5f4c3af7160eb9722
Author: AssHero <hu...@gmail.com>
AuthorDate: Wed Jun 22 21:31:12 2022 +0800
Add support for additional data types in hash join (#2721)
* more data types are supported in hash join
* support decimal/dictionary data types in hashjoin
* add error messages
---
datafusion/core/src/physical_plan/hash_join.rs | 180 ++++++++++++++++++++++++-
datafusion/core/tests/sql/joins.rs | 169 +++++++++++++----------
datafusion/core/tests/sql/mod.rs | 72 ++++++++++
datafusion/expr/src/utils.rs | 8 ++
4 files changed, 358 insertions(+), 71 deletions(-)
diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs
index 042d9525f..96c652f35 100644
--- a/datafusion/core/src/physical_plan/hash_join.rs
+++ b/datafusion/core/src/physical_plan/hash_join.rs
@@ -22,13 +22,17 @@ use ahash::RandomState;
use arrow::{
array::{
- ArrayData, ArrayRef, BooleanArray, Date32Array, Date64Array, LargeStringArray,
+ as_dictionary_array, as_string_array, ArrayData, ArrayRef, BooleanArray,
+ Date32Array, Date64Array, DecimalArray, DictionaryArray, LargeStringArray,
PrimitiveArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampSecondArray, UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder,
UInt64Builder,
},
compute,
- datatypes::{UInt32Type, UInt64Type},
+ datatypes::{
+ Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type,
+ UInt8Type,
+ },
};
use smallvec::{smallvec, SmallVec};
use std::sync::Arc;
@@ -38,7 +42,7 @@ use std::{time::Instant, vec};
use futures::{ready, Stream, StreamExt, TryStreamExt};
use arrow::array::{as_boolean_array, new_null_array, Array};
-use arrow::datatypes::DataType;
+use arrow::datatypes::{ArrowNativeType, DataType};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
@@ -947,6 +951,58 @@ macro_rules! equal_rows_elem {
}};
}
+macro_rules! equal_rows_elem_with_string_dict {
+ ($key_array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{
+ let left_array: &DictionaryArray<$key_array_type> =
+ as_dictionary_array::<$key_array_type>($l);
+ let right_array: &DictionaryArray<$key_array_type> =
+ as_dictionary_array::<$key_array_type>($r);
+
+ let (left_values, left_values_index) = {
+ let keys_col = left_array.keys();
+ if keys_col.is_valid($left) {
+ let values_index = keys_col.value($left).to_usize().ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "Can not convert index to usize in dictionary of type creating group by value {:?}",
+ keys_col.data_type()
+ ))
+ });
+
+ match values_index {
+ Ok(index) => (as_string_array(left_array.values()), Some(index)),
+ _ => (as_string_array(left_array.values()), None)
+ }
+ } else {
+ (as_string_array(left_array.values()), None)
+ }
+ };
+ let (right_values, right_values_index) = {
+ let keys_col = right_array.keys();
+ if keys_col.is_valid($right) {
+ let values_index = keys_col.value($right).to_usize().ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "Can not convert index to usize in dictionary of type creating group by value {:?}",
+ keys_col.data_type()
+ ))
+ });
+
+ match values_index {
+ Ok(index) => (as_string_array(right_array.values()), Some(index)),
+ _ => (as_string_array(right_array.values()), None)
+ }
+ } else {
+ (as_string_array(right_array.values()), None)
+ }
+ };
+
+ match (left_values_index, right_values_index) {
+ (Some(left_values_index), Some(right_values_index)) => left_values.value(left_values_index) == right_values.value(right_values_index),
+ (None, None) => $null_equals_null,
+ _ => false,
+ }
+ }};
+}
+
/// Left and right row have equal values
/// If more data types are supported here, please also add the data types in can_hash function
/// to generate hash join logical plan.
@@ -1054,6 +1110,124 @@ fn equal_rows(
DataType::LargeUtf8 => {
equal_rows_elem!(LargeStringArray, l, r, left, right, null_equals_null)
}
+ DataType::Decimal(_, lscale) => match r.data_type() {
+ DataType::Decimal(_, rscale) => {
+ if lscale == rscale {
+ equal_rows_elem!(
+ DecimalArray,
+ l,
+ r,
+ left,
+ right,
+ null_equals_null
+ )
+ } else {
+ err = Some(Err(DataFusionError::Internal(
+ "Inconsistent Decimal data type in hasher, the scale should be same".to_string(),
+ )));
+ false
+ }
+ }
+ _ => {
+ err = Some(Err(DataFusionError::Internal(
+ "Unsupported data type in hasher".to_string(),
+ )));
+ false
+ }
+ },
+ DataType::Dictionary(key_type, value_type)
+ if *value_type.as_ref() == DataType::Utf8 =>
+ {
+ match key_type.as_ref() {
+ DataType::Int8 => {
+ equal_rows_elem_with_string_dict!(
+ Int8Type,
+ l,
+ r,
+ left,
+ right,
+ null_equals_null
+ )
+ }
+ DataType::Int16 => {
+ equal_rows_elem_with_string_dict!(
+ Int16Type,
+ l,
+ r,
+ left,
+ right,
+ null_equals_null
+ )
+ }
+ DataType::Int32 => {
+ equal_rows_elem_with_string_dict!(
+ Int32Type,
+ l,
+ r,
+ left,
+ right,
+ null_equals_null
+ )
+ }
+ DataType::Int64 => {
+ equal_rows_elem_with_string_dict!(
+ Int64Type,
+ l,
+ r,
+ left,
+ right,
+ null_equals_null
+ )
+ }
+ DataType::UInt8 => {
+ equal_rows_elem_with_string_dict!(
+ UInt8Type,
+ l,
+ r,
+ left,
+ right,
+ null_equals_null
+ )
+ }
+ DataType::UInt16 => {
+ equal_rows_elem_with_string_dict!(
+ UInt16Type,
+ l,
+ r,
+ left,
+ right,
+ null_equals_null
+ )
+ }
+ DataType::UInt32 => {
+ equal_rows_elem_with_string_dict!(
+ UInt32Type,
+ l,
+ r,
+ left,
+ right,
+ null_equals_null
+ )
+ }
+ DataType::UInt64 => {
+ equal_rows_elem_with_string_dict!(
+ UInt64Type,
+ l,
+ r,
+ left,
+ right,
+ null_equals_null
+ )
+ }
+ _ => {
+ // should not happen
+ err = Some(Err(DataFusionError::Internal(
+ "Unsupported data type in hasher".to_string(),
+ )));
+ false
+ }
+ }
+ }
other => {
// This is internal because we should have caught this before.
err = Some(Err(DataFusionError::Internal(format!(
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index 6b3b8c339..0dd948ca6 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1206,29 +1206,11 @@ async fn join_partitioned() -> Result<()> {
}
#[tokio::test]
-async fn join_with_hash_unsupported_data_type() -> Result<()> {
- let ctx = SessionContext::new();
-
- let schema = Schema::new(vec![
- Field::new("c1", DataType::Int32, true),
- Field::new("c2", DataType::Utf8, true),
- Field::new("c3", DataType::Int64, true),
- Field::new("c4", DataType::Date32, true),
- ]);
- let data = RecordBatch::try_new(
- Arc::new(schema),
- vec![
- Arc::new(Int32Array::from_slice(&[1, 2, 3])),
- Arc::new(StringArray::from_slice(&["aaa", "bbb", "ccc"])),
- Arc::new(Int64Array::from_slice(&[100, 200, 300])),
- Arc::new(Date32Array::from(vec![Some(1), Some(2), Some(3)])),
- ],
- )?;
- let table = MemTable::try_new(data.schema(), vec![vec![data]])?;
- ctx.register_table("foo", Arc::new(table))?;
+async fn hash_join_with_date32() -> Result<()> {
+ let ctx = create_hashjoin_datatype_context()?;
- // join on hash unsupported data type (Date32), use cross join instead hash join
- let sql = "select * from foo t1 join foo t2 on t1.c4 = t2.c4";
+ // inner join on hash supported data type (Date32)
+ let sql = "select * from t1 join t2 on t1.c1 = t2.c1";
let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx
.create_logical_plan(&("explain ".to_owned() + sql))
@@ -1237,13 +1219,10 @@ async fn join_with_hash_unsupported_data_type() -> Result<()> {
let plan = state.optimize(&plan)?;
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
- " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " Filter: #t1.c4 = #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " CrossJoin: [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " Inner Join: #t1.c1 = #t2.c1 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " TableScan: t1 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " TableScan: t2 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -1254,32 +1233,38 @@ async fn join_with_hash_unsupported_data_type() -> Result<()> {
);
let expected = vec![
- "+----+-----+-----+------------+----+-----+-----+------------+",
- "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |",
- "+----+-----+-----+------------+----+-----+-----+------------+",
- "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |",
- "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |",
- "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |",
- "+----+-----+-----+------------+----+-----+-----+------------+",
+ "+------------+------------+---------+-----+------------+------------+---------+-----+",
+ "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |",
+ "+------------+------------+---------+-----+------------+------------+---------+-----+",
+ "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 1970-01-02 | 1970-01-02 | -123.12 | abc |",
+ "| 1970-01-04 | | -123.12 | jkl | 1970-01-04 | | 789.00 | |",
+ "+------------+------------+---------+-----+------------+------------+---------+-----+",
];
let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);
- // join on hash supported data type (Int32), use hash join
- let sql = "select * from foo t1 join foo t2 on t1.c1 = t2.c1";
+ Ok(())
+}
+
+#[tokio::test]
+async fn hash_join_with_date64() -> Result<()> {
+ let ctx = create_hashjoin_datatype_context()?;
+
+ // left join on hash supported data type (Date64)
+ let sql = "select * from t1 left join t2 on t1.c2 = t2.c2";
+ let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx
.create_logical_plan(&("explain ".to_owned() + sql))
.expect(&msg);
+ let state = ctx.state();
let plan = state.optimize(&plan)?;
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
- " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " Inner Join: #t1.c1 = #t2.c1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " Left Join: #t1.c2 = #t2.c2 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " TableScan: t1 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " TableScan: t2 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -1290,34 +1275,84 @@ async fn join_with_hash_unsupported_data_type() -> Result<()> {
);
let expected = vec![
- "+----+-----+-----+------------+----+-----+-----+------------+",
- "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |",
- "+----+-----+-----+------------+----+-----+-----+------------+",
- "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |",
- "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |",
- "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |",
- "+----+-----+-----+------------+----+-----+-----+------------+",
+ "+------------+------------+---------+-----+------------+------------+---------+--------+",
+ "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |",
+ "+------------+------------+---------+-----+------------+------------+---------+--------+",
+ "| | 1970-01-04 | 789.00 | ghi | | 1970-01-04 | 0.00 | qwerty |",
+ "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 1970-01-02 | 1970-01-02 | -123.12 | abc |",
+ "| 1970-01-03 | 1970-01-03 | 456.00 | def | | | | |",
+ "| 1970-01-04 | | -123.12 | jkl | | | | |",
+ "+------------+------------+---------+-----+------------+------------+---------+--------+",
];
let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);
- // join on two columns, hash supported data type(Int64) and hash unsupported data type (Date32),
- // use hash join on Int64 column, and filter on Date32 column.
- let sql = "select * from foo t1, foo t2 where t1.c3 = t2.c3 and t1.c4 = t2.c4";
+ Ok(())
+}
+
+#[tokio::test]
+async fn hash_join_with_decimal() -> Result<()> {
+ let ctx = create_hashjoin_datatype_context()?;
+
+ // right join on hash supported data type (Decimal)
+ let sql = "select * from t1 right join t2 on t1.c3 = t2.c3";
+ let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx
.create_logical_plan(&("explain ".to_owned() + sql))
.expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " Right Join: #t1.c3 = #t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " TableScan: t1 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " TableScan: t2 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let expected = vec![
+ "+------------+------------+---------+-----+------------+------------+-----------+---------+",
+ "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |",
+ "+------------+------------+---------+-----+------------+------------+-----------+---------+",
+ "| | | | | | | 100000.00 | abcdefg |",
+ "| | | | | | 1970-01-04 | 0.00 | qwerty |",
+ "| | 1970-01-04 | 789.00 | ghi | 1970-01-04 | | 789.00 | |",
+ "| 1970-01-04 | | -123.12 | jkl | 1970-01-02 | 1970-01-02 | -123.12 | abc |",
+ "+------------+------------+---------+-----+------------+------------+-----------+---------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn hash_join_with_dictionary() -> Result<()> {
+ let ctx = create_hashjoin_datatype_context()?;
+
+ // inner join on hash supported data type (Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)))
+ let sql = "select * from t1 join t2 on t1.c4 = t2.c4";
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let state = ctx.state();
let plan = state.optimize(&plan)?;
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
- " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " Filter: #t1.c4 = #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " Inner Join: #t1.c3 = #t2.c3 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
- " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]",
+ " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " Inner Join: #t1.c4 = #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " TableScan: t1 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N]",
+ " TableScan: t2 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -1328,13 +1363,11 @@ async fn join_with_hash_unsupported_data_type() -> Result<()> {
);
let expected = vec![
- "+----+-----+-----+------------+----+-----+-----+------------+",
- "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |",
- "+----+-----+-----+------------+----+-----+-----+------------+",
- "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |",
- "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |",
- "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |",
- "+----+-----+-----+------------+----+-----+-----+------------+",
+ "+------------+------------+------+-----+------------+------------+---------+-----+",
+ "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |",
+ "+------------+------------+------+-----+------------+------------+---------+-----+",
+ "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 1970-01-02 | 1970-01-02 | -123.12 | abc |",
+ "+------------+------------+------+-----+------------+------------+---------+-----+",
];
let results = execute_to_batches(&ctx, sql).await;
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index 3e19dbcb9..0e3e08873 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -262,6 +262,78 @@ fn create_join_context_qualified() -> Result<SessionContext> {
Ok(ctx)
}
+fn create_hashjoin_datatype_context() -> Result<SessionContext> {
+ let ctx = SessionContext::new();
+
+ let t1_schema = Schema::new(vec![
+ Field::new("c1", DataType::Date32, true),
+ Field::new("c2", DataType::Date64, true),
+ Field::new("c3", DataType::Decimal(5, 2), true),
+ Field::new(
+ "c4",
+ DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+ true,
+ ),
+ ]);
+ let dict1: DictionaryArray<Int32Type> =
+ vec!["abc", "def", "ghi", "jkl"].into_iter().collect();
+ let t1_data = RecordBatch::try_new(
+ Arc::new(t1_schema),
+ vec![
+ Arc::new(Date32Array::from(vec![Some(1), Some(2), None, Some(3)])),
+ Arc::new(Date64Array::from(vec![
+ Some(86400000),
+ Some(172800000),
+ Some(259200000),
+ None,
+ ])),
+ Arc::new(
+ DecimalArray::from_iter_values([123, 45600, 78900, -12312])
+ .with_precision_and_scale(5, 2)
+ .unwrap(),
+ ),
+ Arc::new(dict1),
+ ],
+ )?;
+ let table = MemTable::try_new(t1_data.schema(), vec![vec![t1_data]])?;
+ ctx.register_table("t1", Arc::new(table))?;
+
+ let t2_schema = Schema::new(vec![
+ Field::new("c1", DataType::Date32, true),
+ Field::new("c2", DataType::Date64, true),
+ Field::new("c3", DataType::Decimal(10, 2), true),
+ Field::new(
+ "c4",
+ DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+ true,
+ ),
+ ]);
+ let dict2: DictionaryArray<Int32Type> =
+ vec!["abc", "abcdefg", "qwerty", ""].into_iter().collect();
+ let t2_data = RecordBatch::try_new(
+ Arc::new(t2_schema),
+ vec![
+ Arc::new(Date32Array::from(vec![Some(1), None, None, Some(3)])),
+ Arc::new(Date64Array::from(vec![
+ Some(86400000),
+ None,
+ Some(259200000),
+ None,
+ ])),
+ Arc::new(
+ DecimalArray::from_iter_values([-12312, 10000000, 0, 78900])
+ .with_precision_and_scale(10, 2)
+ .unwrap(),
+ ),
+ Arc::new(dict2),
+ ],
+ )?;
+ let table = MemTable::try_new(t2_data.schema(), vec![vec![t2_data]])?;
+ ctx.register_table("t2", Arc::new(table))?;
+
+ Ok(ctx)
+}
+
/// the table column_left has more rows than the table column_right
fn create_join_context_unbalanced(
column_left: &str,
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index a85a817a8..75180189a 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -682,6 +682,14 @@ pub fn can_hash(data_type: &DataType) -> bool {
},
DataType::Utf8 => true,
DataType::LargeUtf8 => true,
+ DataType::Decimal(_, _) => true,
+ DataType::Date32 => true,
+ DataType::Date64 => true,
+ DataType::Dictionary(key_type, value_type)
+ if *value_type.as_ref() == DataType::Utf8 =>
+ {
+ DataType::is_dictionary_key_type(key_type)
+ }
_ => false,
}
}