Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/06/19 11:49:27 UTC

[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #2721: Add additional data types are supported in hash join

alamb commented on code in PR #2721:
URL: https://github.com/apache/arrow-datafusion/pull/2721#discussion_r901089436


##########
datafusion/expr/src/utils.rs:
##########
@@ -682,6 +682,24 @@ pub fn can_hash(data_type: &DataType) -> bool {
         },
         DataType::Utf8 => true,
         DataType::LargeUtf8 => true,
+        DataType::Decimal(_, _) => true,
+        DataType::Date32 => true,
+        DataType::Date64 => true,
+        DataType::Dictionary(key_type, value_type)
+            if *value_type.as_ref() == DataType::Utf8 =>
+        {
+            matches!(

Review Comment:
   minor comment: could potentially use `DataType::is_dictionary_key_type` here: https://docs.rs/arrow/16.0.0/arrow/datatypes/enum.DataType.html#method.is_dictionary_key_type
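
   A minimal sketch of what that suggestion might look like, assuming the Dictionary arm keeps its `Utf8` value-type guard and only the key check changes (the function name and the reduced set of arms are illustrative, not the real `can_hash`):

   ```rust
   use arrow::datatypes::DataType;

   // Sketch only: the real `can_hash` in datafusion/expr/src/utils.rs covers
   // many more types. The Dictionary arm delegates the key check to
   // `DataType::is_dictionary_key_type` instead of an explicit `matches!`.
   fn can_hash_sketch(data_type: &DataType) -> bool {
       match data_type {
           DataType::Utf8 | DataType::LargeUtf8 => true,
           DataType::Date32 | DataType::Date64 => true,
           DataType::Dictionary(key_type, value_type)
               if *value_type.as_ref() == DataType::Utf8 =>
           {
               DataType::is_dictionary_key_type(key_type)
           }
           _ => false,
       }
   }
   ```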



##########
datafusion/core/tests/sql/joins.rs:
##########
@@ -1206,29 +1206,83 @@ async fn join_partitioned() -> Result<()> {
 }
 
 #[tokio::test]
-async fn join_with_hash_unsupported_data_type() -> Result<()> {
+async fn join_with_hash_supported_data_type() -> Result<()> {
     let ctx = SessionContext::new();
 
-    let schema = Schema::new(vec![
-        Field::new("c1", DataType::Int32, true),
-        Field::new("c2", DataType::Utf8, true),
-        Field::new("c3", DataType::Int64, true),
-        Field::new("c4", DataType::Date32, true),
+    let t1_schema = Schema::new(vec![
+        Field::new("c1", DataType::Date32, true),
+        Field::new("c2", DataType::Date64, true),
+        Field::new("c3", DataType::Decimal(5, 2), true),
+        Field::new(
+            "c4",
+            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+            true,
+        ),
+        Field::new("c5", DataType::Binary, true),
     ]);
-    let data = RecordBatch::try_new(
-        Arc::new(schema),
+    let dict1: DictionaryArray<Int32Type> =
+        vec!["abc", "def", "ghi", "jkl"].into_iter().collect();
+    let binary_value1: Vec<&[u8]> = vec![b"one", b"two", b"", b"three"];
+    let t1_data = RecordBatch::try_new(
+        Arc::new(t1_schema),
+        vec![
+            Arc::new(Date32Array::from(vec![Some(1), Some(2), None, Some(3)])),
+            Arc::new(Date64Array::from(vec![
+                Some(86400000),
+                Some(172800000),
+                Some(259200000),
+                None,
+            ])),
+            Arc::new(
+                DecimalArray::from_iter_values([123, 45600, 78900, -12312])
+                    .with_precision_and_scale(5, 2)
+                    .unwrap(),
+            ),
+            Arc::new(dict1),
+            Arc::new(BinaryArray::from_vec(binary_value1)),
+        ],
+    )?;
+    let table = MemTable::try_new(t1_data.schema(), vec![vec![t1_data]])?;
+    ctx.register_table("t1", Arc::new(table))?;
+
+    let t2_schema = Schema::new(vec![
+        Field::new("c1", DataType::Date32, true),
+        Field::new("c2", DataType::Date64, true),
+        Field::new("c3", DataType::Decimal(10, 2), true),
+        Field::new(
+            "c4",
+            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+            true,
+        ),
+        Field::new("c5", DataType::Binary, true),
+    ]);
+    let dict2: DictionaryArray<Int32Type> =
+        vec!["abc", "abcdefg", "qwerty", ""].into_iter().collect();
+    let binary_value2: Vec<&[u8]> = vec![b"one", b"", b"two", b"three"];
+    let t2_data = RecordBatch::try_new(
+        Arc::new(t2_schema),
         vec![
-            Arc::new(Int32Array::from_slice(&[1, 2, 3])),
-            Arc::new(StringArray::from_slice(&["aaa", "bbb", "ccc"])),
-            Arc::new(Int64Array::from_slice(&[100, 200, 300])),
-            Arc::new(Date32Array::from(vec![Some(1), Some(2), Some(3)])),
+            Arc::new(Date32Array::from(vec![Some(1), None, None, Some(3)])),
+            Arc::new(Date64Array::from(vec![
+                Some(86400000),
+                None,
+                Some(259200000),
+                None,
+            ])),
+            Arc::new(
+                DecimalArray::from_iter_values([-12312, 10000000, 0, 78900])
+                    .with_precision_and_scale(10, 2)
+                    .unwrap(),
+            ),
+            Arc::new(dict2),
+            Arc::new(BinaryArray::from_vec(binary_value2)),
         ],
     )?;
-    let table = MemTable::try_new(data.schema(), vec![vec![data]])?;
-    ctx.register_table("foo", Arc::new(table))?;
+    let table = MemTable::try_new(t2_data.schema(), vec![vec![t2_data]])?;
+    ctx.register_table("t2", Arc::new(table))?;
 
-    // join on hash unsupported data type (Date32), use cross join instead hash join
-    let sql = "select * from foo t1 join foo t2 on t1.c4 = t2.c4";
+    // inner join on hash supported data type (Date32)

Review Comment:
   As a stylistic thing, I find tests easier to debug and work with when each individual test contains fewer queries.
   
   So in this case, perhaps something like
   
   
   ```rust
    #[tokio::test]
    async fn hash_join_date32() -> Result<()> {
       let ctx = make_test_hash_join_ctx();
   
       let sql = "select * from t1 join t2 on t1.c1 = t2.c1";
       let msg = format!("Creating logical plan for '{}'", sql);
       let plan = ctx
           .create_logical_plan(&("explain ".to_owned() + sql))
           .expect(&msg);
       let state = ctx.state();
       let plan = state.optimize(&plan)?;
       let expected = vec![
           "Explain [plan_type:Utf8, plan:Utf8]",
           "  Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t1.c5, #t2.c1, #t2.c2, #t2.c3, #t2.c4, #t2.c5 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]",
           "    Inner Join: #t1.c1 = #t2.c1 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]",
           "      TableScan: t1 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]",
           "      TableScan: t2 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]",
       ];
       let formatted = plan.display_indent_schema().to_string();
       let actual: Vec<&str> = formatted.trim().lines().collect();
       assert_eq!(
           expected, actual,
           "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
           expected, actual
       );
   
       let expected = vec![
           "+------------+------------+---------+-----+------------+------------+------------+---------+-----+------------+",
           "| c1         | c2         | c3      | c4  | c5         | c1         | c2         | c3      | c4  | c5         |",
           "+------------+------------+---------+-----+------------+------------+------------+---------+-----+------------+",
           "| 1970-01-02 | 1970-01-02 | 1.23    | abc | 6f6e65     | 1970-01-02 | 1970-01-02 | -123.12 | abc | 6f6e65     |",
           "| 1970-01-04 |            | -123.12 | jkl | 7468726565 | 1970-01-04 |            | 789.00  |     | 7468726565 |",
           "+------------+------------+---------+-----+------------+------------+------------+---------+-----+------------+",
       ];
   
       let results = execute_to_batches(&ctx, sql).await;
        assert_batches_sorted_eq!(expected, &results);

        Ok(())
    }
    ```
   
    (that way a single `cargo test` run reports results for each of the smaller tests, rather than having to re-run one big test after every failing assertion)
   



##########
datafusion/core/src/physical_plan/hash_join.rs:
##########
@@ -1054,6 +1079,102 @@ fn equal_rows(
             DataType::LargeUtf8 => {
                 equal_rows_elem!(LargeStringArray, l, r, left, right, null_equals_null)
             }
+            DataType::Decimal(_, _) => {

Review Comment:
   For decimal, I wonder if we also need to ensure that the precision and scale are the same (e.g. `l.data_type() == r.data_type()`) 🤔 
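
   A rough, free-standing sketch of that extra check (the names and the fall-back behaviour are illustrative, not the `equal_rows_elem!` macro in hash_join.rs):

   ```rust
   use arrow::array::{Array, ArrayRef, DecimalArray};

   // Sketch: before comparing two decimal values positionally, confirm both
   // columns carry the same precision and scale (i.e. the same DataType),
   // since comparing the stored values only makes sense when the scales match.
   fn decimal_rows_equal(
       l: &ArrayRef,
       r: &ArrayRef,
       left: usize,
       right: usize,
       null_equals_null: bool,
   ) -> bool {
       if l.data_type() != r.data_type() {
           // Differing precision/scale: treat as not equal here
           // (the real code might instead reject the join key earlier).
           return false;
       }
       let l = l.as_any().downcast_ref::<DecimalArray>().unwrap();
       let r = r.as_any().downcast_ref::<DecimalArray>().unwrap();
       match (l.is_null(left), r.is_null(right)) {
           (true, true) => null_equals_null,
           (false, false) => l.value(left) == r.value(right),
           _ => false,
       }
   }
   ```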



##########
datafusion/core/tests/sql/joins.rs:
##########
@@ -1254,32 +1373,31 @@ async fn join_with_hash_unsupported_data_type() -> Result<()> {
     );
 
     let expected = vec![
-        "+----+-----+-----+------------+----+-----+-----+------------+",
-        "| c1 | c2  | c3  | c4         | c1 | c2  | c3  | c4         |",
-        "+----+-----+-----+------------+----+-----+-----+------------+",
-        "| 1  | aaa | 100 | 1970-01-02 | 1  | aaa | 100 | 1970-01-02 |",
-        "| 2  | bbb | 200 | 1970-01-03 | 2  | bbb | 200 | 1970-01-03 |",
-        "| 3  | ccc | 300 | 1970-01-04 | 3  | ccc | 300 | 1970-01-04 |",
-        "+----+-----+-----+------------+----+-----+-----+------------+",
+        "+------------+------------+---------+-----+------------+------------+------------+-----------+---------+------------+",
+        "| c1         | c2         | c3      | c4  | c5         | c1         | c2         | c3        | c4      | c5         |",
+        "+------------+------------+---------+-----+------------+------------+------------+-----------+---------+------------+",
+        "|            |            |         |     |            |            |            | 100000.00 | abcdefg |            |",
+        "|            |            |         |     |            |            | 1970-01-04 | 0.00      | qwerty  | 74776f     |",
+        "|            | 1970-01-04 | 789.00  | ghi |            | 1970-01-04 |            | 789.00    |         | 7468726565 |",
+        "| 1970-01-04 |            | -123.12 | jkl | 7468726565 | 1970-01-02 | 1970-01-02 | -123.12   | abc     | 6f6e65     |",
+        "+------------+------------+---------+-----+------------+------------+------------+-----------+---------+------------+",
     ];
 
     let results = execute_to_batches(&ctx, sql).await;
     assert_batches_sorted_eq!(expected, &results);

Review Comment:
   This test is failing on CI: https://github.com/apache/arrow-datafusion/runs/6953453815?check_suite_focus=true
   
   Looking at the diff it appears to be related to column c5: 
   ![Screen Shot 2022-06-19 at 7 45 59 AM](https://user-images.githubusercontent.com/490673/174479269-31385d43-a730-4a21-a10e-ccbfbf42b0a0.png)
   
   It therefore seems unrelated to this PR (though it is a real bug), so I personally suggest changing the query to not select `c5`, and then filing a ticket to chase down the real problem.



##########
datafusion/core/src/physical_plan/hash_join.rs:
##########
@@ -947,6 +951,27 @@ macro_rules! equal_rows_elem {
     }};
 }
 
+macro_rules! equal_rows_elem_with_string_dict {
+    ($key_array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{

Review Comment:
   For dictionaries, I think `$left` and `$right` are actually indexes into the *keys* array, and then the keys array contains the corresponding index into `values`.
   
   ```text
   ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─                                                
     ┌─────────────────┐  ┌─────────┐ │     ┌─────────────────┐                       
   │ │        A        │  │    0    │       │        A        │     values[keys[0]]   
     ├─────────────────┤  ├─────────┤ │     ├─────────────────┤                       
   │ │        D        │  │    2    │       │        B        │     values[keys[1]]   
     ├─────────────────┤  ├─────────┤ │     ├─────────────────┤                       
   │ │        B        │  │    2    │       │        B        │     values[keys[2]]   
     ├─────────────────┤  ├─────────┤ │     ├─────────────────┤                       
   │ │        C        │  │    1    │       │        D        │     values[keys[3]]   
     ├─────────────────┤  └─────────┘ │     └─────────────────┘                       
   │ │        E        │     keys                                                     
     └─────────────────┘              │                                               
   │       values                             Logical array                           
    ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘          Contents                             
                                                                                      
             DictionaryArray                                                          
                length = 4                                                            
                                                                                      
   ```
   
   In other words, I think you need to compare the values using something like:
   https://github.com/AssHero/arrow-datafusion/blob/hashjoin/datafusion/common/src/scalar.rs#L338-L361
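
   For reference, a standalone sketch of that lookup (illustrative names, not the `equal_rows_elem_with_string_dict!` macro itself): resolve each row's key, then compare the values the keys point at.

   ```rust
   use arrow::array::{Array, DictionaryArray, StringArray};
   use arrow::datatypes::Int32Type;

   // Sketch: `left`/`right` index the keys array; each key in turn indexes
   // the values array, so equality must be decided on the resolved strings,
   // not on the raw keys (the same string can sit behind different key ids
   // in the two arrays).
   fn dict_utf8_rows_equal(
       l: &DictionaryArray<Int32Type>,
       r: &DictionaryArray<Int32Type>,
       left: usize,
       right: usize,
       null_equals_null: bool,
   ) -> bool {
       match (l.is_null(left), r.is_null(right)) {
           (true, true) => null_equals_null,
           (false, false) => {
               let l_values = l.values().as_any().downcast_ref::<StringArray>().unwrap();
               let r_values = r.values().as_any().downcast_ref::<StringArray>().unwrap();
               let l_key = l.keys().value(left) as usize;
               let r_key = r.keys().value(right) as usize;
               l_values.value(l_key) == r_values.value(r_key)
           }
           _ => false,
       }
   }
   ```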


