You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/22 22:17:56 UTC

[arrow-datafusion] branch master updated: fix bug: right semi join will ignore filter with left side, and add test for test (#4327)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 92325bf16 fix bug: right semi join will ignore filter with left side, and add test for test (#4327)
92325bf16 is described below

commit 92325bf167c0d215564695717cb56d01448887c8
Author: Kun Liu <li...@apache.org>
AuthorDate: Wed Nov 23 06:17:52 2022 +0800

    fix bug: right semi join will ignore filter with left side, and add test for test (#4327)
---
 .../core/src/physical_plan/joins/hash_join.rs      | 127 ++++++++++++++++++++-
 1 file changed, 123 insertions(+), 4 deletions(-)

diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs
index d557ca6ed..ead84e66c 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -852,6 +852,7 @@ fn build_join_indexes(
                             &keys_values,
                             *null_equals_null,
                         )? {
+                            left_indices.append(i);
                             right_indices.append(row as u32);
                             break;
                         }
@@ -1611,6 +1612,8 @@ mod tests {
 
     use super::*;
     use crate::prelude::SessionContext;
+    use datafusion_common::ScalarValue;
+    use datafusion_physical_expr::expressions::Literal;
     use std::sync::Arc;
 
     fn build_table(
@@ -2289,7 +2292,7 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn join_semi() -> Result<()> {
+    async fn join_left_semi() -> Result<()> {
         let session_ctx = SessionContext::new();
         let task_ctx = session_ctx.task_ctx();
         let left = build_table(
@@ -2329,6 +2332,63 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn join_left_semi_with_filter() -> Result<()> {
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+        let left = build_table(
+            ("a1", &vec![1, 2, 2, 3]),
+            ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right
+            ("c1", &vec![7, 8, 8, 9]),
+        );
+        let right = build_table(
+            ("a2", &vec![10, 20, 30, 40]),
+            ("b1", &vec![4, 5, 6, 5]), // 5 appears twice on the right
+            ("c2", &vec![70, 80, 90, 100]),
+        );
+        let on = vec![(
+            Column::new_with_schema("b1", &left.schema())?,
+            Column::new_with_schema("b1", &right.schema())?,
+        )];
+
+        // build filter right.b1 > 4
+        let column_indices = vec![ColumnIndex {
+            index: 1,
+            side: JoinSide::Right,
+        }];
+        let intermediate_schema =
+            Schema::new(vec![Field::new("x", DataType::Int32, true)]);
+
+        let filter_expression = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("x", 0)),
+            Operator::Gt,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(4)))),
+        )) as Arc<dyn PhysicalExpr>;
+
+        let filter =
+            JoinFilter::new(filter_expression, column_indices, intermediate_schema);
+
+        let join = join_with_filter(left, right, on, filter, &JoinType::LeftSemi, false)?;
+
+        let columns = columns(&join.schema());
+        assert_eq!(columns, vec!["a1", "b1", "c1"]);
+
+        let stream = join.execute(0, task_ctx)?;
+        let batches = common::collect(stream).await?;
+
+        let expected = vec![
+            "+----+----+----+",
+            "| a1 | b1 | c1 |",
+            "+----+----+----+",
+            "| 2  | 5  | 8  |",
+            "| 2  | 5  | 8  |",
+            "+----+----+----+",
+        ];
+        assert_batches_sorted_eq!(expected, &batches);
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn join_right_semi() -> Result<()> {
         let session_ctx = SessionContext::new();
@@ -2372,7 +2432,66 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn join_anti() -> Result<()> {
+    async fn join_right_semi_with_filter() -> Result<()> {
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+        let left = build_table(
+            ("a2", &vec![10, 20, 30, 40]),
+            ("b2", &vec![4, 5, 6, 5]), // 5 appears twice on the left
+            ("c2", &vec![70, 80, 90, 100]),
+        );
+        let right = build_table(
+            ("a1", &vec![1, 2, 2, 3]),
+            ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the left
+            ("c1", &vec![7, 8, 8, 9]),
+        );
+
+        let on = vec![(
+            Column::new_with_schema("b2", &left.schema())?,
+            Column::new_with_schema("b1", &right.schema())?,
+        )];
+
+        // build filter left.b2 > 4
+        let column_indices = vec![ColumnIndex {
+            index: 1,
+            side: JoinSide::Left,
+        }];
+        let intermediate_schema =
+            Schema::new(vec![Field::new("x", DataType::Int32, true)]);
+
+        let filter_expression = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("x", 0)),
+            Operator::Gt,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(4)))),
+        )) as Arc<dyn PhysicalExpr>;
+
+        let filter =
+            JoinFilter::new(filter_expression, column_indices, intermediate_schema);
+
+        let join =
+            join_with_filter(left, right, on, filter, &JoinType::RightSemi, false)?;
+
+        let columns = columns(&join.schema());
+        assert_eq!(columns, vec!["a1", "b1", "c1"]);
+
+        let stream = join.execute(0, task_ctx)?;
+        let batches = common::collect(stream).await?;
+
+        let expected = vec![
+            "+----+----+----+",
+            "| a1 | b1 | c1 |",
+            "+----+----+----+",
+            "| 2  | 5  | 8  |",
+            "| 2  | 5  | 8  |",
+            "+----+----+----+",
+        ];
+        assert_batches_sorted_eq!(expected, &batches);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn join_left_anti() -> Result<()> {
         let session_ctx = SessionContext::new();
         let task_ctx = session_ctx.task_ctx();
         let left = build_table(
@@ -2450,7 +2569,7 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn join_anti_with_filter() -> Result<()> {
+    async fn join_left_anti_with_filter() -> Result<()> {
         let session_ctx = SessionContext::new();
         let task_ctx = session_ctx.task_ctx();
         let left = build_table(
@@ -2466,7 +2585,7 @@ mod tests {
             Column::new_with_schema("col1", &right.schema())?,
         )];
 
-        // build filter b.col2 <> a.col2
+        // build filter left.col2 <> right.col2
         let column_indices = vec![
             ColumnIndex {
                 index: 1,