You are viewing a plain-text version of this content. The canonical version was linked from the word "here" in the original page; that hyperlink is not preserved in this text copy.
Posted to commits@arrow.apache.org by ag...@apache.org on 2023/01/27 01:38:47 UTC

[arrow-datafusion] branch master updated: Add null-equals-null join support (#5085)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new ad3cc242a Add null-equals-null join support (#5085)
ad3cc242a is described below

commit ad3cc242ad55bd83a35ba3b4009b63a4bbc8c24a
Author: Nuttiiya Seekhao <37...@users.noreply.github.com>
AuthorDate: Thu Jan 26 20:38:40 2023 -0500

    Add null-equals-null join support (#5085)
---
 datafusion/substrait/Cargo.toml         |  1 +
 datafusion/substrait/src/consumer.rs    | 45 +++++++++++++++++++++------------
 datafusion/substrait/src/producer.rs    | 21 +++++++++------
 datafusion/substrait/tests/roundtrip.rs | 16 ++++++++++++
 4 files changed, 59 insertions(+), 24 deletions(-)

diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml
index 5e9beef43..a007c8312 100644
--- a/datafusion/substrait/Cargo.toml
+++ b/datafusion/substrait/Cargo.toml
@@ -24,6 +24,7 @@ rust-version = "1.62"
 [dependencies]
 async-recursion = "1.0"
 datafusion = { version = "17.0.0", path = "../core" }
+itertools = "0.10.5"
 prost = "0.11"
 prost-types = "0.11"
 substrait = "0.4"
diff --git a/datafusion/substrait/src/consumer.rs b/datafusion/substrait/src/consumer.rs
index e7303e21c..52f1f052a 100644
--- a/datafusion/substrait/src/consumer.rs
+++ b/datafusion/substrait/src/consumer.rs
@@ -325,29 +325,42 @@ pub async fn from_substrait_rel(
             )
             .await?;
             let predicates = split_conjunction(&on);
-            let pairs = predicates
+            // TODO: collect only one null_eq_null
+            let join_exprs: Vec<(Column, Column, bool)> = predicates
                 .iter()
                 .map(|p| match p {
-                    Expr::BinaryExpr(BinaryExpr {
-                        left,
-                        op: Operator::Eq,
-                        right,
-                    }) => match (left.as_ref(), right.as_ref()) {
-                        (Expr::Column(l), Expr::Column(r)) => {
-                            Ok((l.flat_name(), r.flat_name()))
+                    Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
+                        match (left.as_ref(), right.as_ref()) {
+                            (Expr::Column(l), Expr::Column(r)) => match op {
+                                Operator::Eq => Ok((l.clone(), r.clone(), false)),
+                                Operator::IsNotDistinctFrom => {
+                                    Ok((l.clone(), r.clone(), true))
+                                }
+                                _ => Err(DataFusionError::Internal(
+                                    "invalid join condition op".to_string(),
+                                )),
+                            },
+                            _ => Err(DataFusionError::Internal(
+                                "invalid join condition expresssion".to_string(),
+                            )),
                         }
-                        _ => Err(DataFusionError::Internal(
-                            "invalid join condition".to_string(),
-                        )),
-                    },
+                    }
                     _ => Err(DataFusionError::Internal(
-                        "invalid join condition".to_string(),
+                        "Non-binary expression is not supported in join condition"
+                            .to_string(),
                     )),
                 })
                 .collect::<Result<Vec<_>>>()?;
-            let (left_cols, right_cols): (Vec<_>, Vec<_>) = pairs.iter().cloned().unzip();
-            left.join(right.build()?, join_type, (left_cols, right_cols), None)?
-                .build()
+            let (left_cols, right_cols, null_eq_nulls): (Vec<_>, Vec<_>, Vec<_>) =
+                itertools::multiunzip(join_exprs);
+            left.join_detailed(
+                right.build()?,
+                join_type,
+                (left_cols, right_cols),
+                None,
+                null_eq_nulls[0],
+            )?
+            .build()
         }
         Some(RelType::Read(read)) => match &read.as_ref().read_type {
             Some(ReadType::NamedTable(nt)) => {
diff --git a/datafusion/substrait/src/producer.rs b/datafusion/substrait/src/producer.rs
index 3a9c4923b..ef8a52d9d 100644
--- a/datafusion/substrait/src/producer.rs
+++ b/datafusion/substrait/src/producer.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::collections::HashMap;
+use std::{collections::HashMap, sync::Arc};
 
 use datafusion::{
     error::{DataFusionError, Result},
@@ -246,11 +246,6 @@ pub fn to_substrait_rel(
             let right = to_substrait_rel(join.right.as_ref(), extension_info)?;
             let join_type = to_substrait_jointype(join.join_type);
             // we only support basic joins so return an error for anything not yet supported
-            if join.null_equals_null {
-                return Err(DataFusionError::NotImplemented(
-                    "join null_equals_null".to_string(),
-                ));
-            }
             if join.filter.is_some() {
                 return Err(DataFusionError::NotImplemented("join filter".to_string()));
             }
@@ -264,11 +259,20 @@ pub fn to_substrait_rel(
             }
             // map the left and right columns to binary expressions in the form `l = r`
             // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b`
+            let eq_op = if join.null_equals_null {
+                Operator::IsNotDistinctFrom
+            } else {
+                Operator::Eq
+            };
             let join_expression = join
                 .on
                 .iter()
-                .map(|(l, r)| binary_expr(l.clone(), Operator::Eq, r.clone()))
+                .map(|(l, r)| binary_expr(l.clone(), eq_op, r.clone()))
                 .reduce(|acc: Expr, expr: Expr| acc.and(expr));
+            // join schema from left and right to maintain all nececesary columns from inputs
+            // note that we cannot simple use join.schema here since we discard some input columns
+            // when performing semi and anti joins
+            let join_schema = join.left.schema().join(join.right.schema());
             if let Some(e) = join_expression {
                 Ok(Box::new(Rel {
                     rel_type: Some(RelType::Join(Box::new(JoinRel {
@@ -278,7 +282,7 @@ pub fn to_substrait_rel(
                         r#type: join_type as i32,
                         expression: Some(Box::new(to_substrait_rex(
                             &e,
-                            &join.schema,
+                            &Arc::new(join_schema?),
                             extension_info,
                         )?)),
                         post_join_filter: None,
@@ -579,6 +583,7 @@ pub fn to_substrait_rex(
                 ScalarValue::Int16(Some(n)) => Some(LiteralType::I16(*n as i32)),
                 ScalarValue::Int32(Some(n)) => Some(LiteralType::I32(*n)),
                 ScalarValue::Int64(Some(n)) => Some(LiteralType::I64(*n)),
+                ScalarValue::UInt8(Some(n)) => Some(LiteralType::I16(*n as i32)), // Substrait currently does not support unsigned integer
                 ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)),
                 ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
                 ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
diff --git a/datafusion/substrait/tests/roundtrip.rs b/datafusion/substrait/tests/roundtrip.rs
index 0ddde1754..f9bb53bd0 100644
--- a/datafusion/substrait/tests/roundtrip.rs
+++ b/datafusion/substrait/tests/roundtrip.rs
@@ -219,12 +219,28 @@ mod tests {
             .await
     }
 
+    #[tokio::test]
+    async fn simple_intersect() -> Result<()> {
+        assert_expected_plan(
+            "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);",
+            "Projection: COUNT(Int16(1))\
+            \n  Aggregate: groupBy=[[]], aggr=[[COUNT(Int16(1))]]\
+            \n    LeftSemi Join: data.a = data2.a\
+            \n      Aggregate: groupBy=[[data.a]], aggr=[[]]\
+            \n        TableScan: data projection=[a]\
+            \n      Projection: data2.a\
+            \n        TableScan: data2 projection=[a]",
+        )
+        .await
+    }
+
     async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> {
         let mut ctx = create_context().await?;
         let df = ctx.sql(sql).await?;
         let plan = df.into_optimized_plan()?;
         let proto = to_substrait_plan(&plan)?;
         let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
+        let plan2 = ctx.state().optimize(&plan2)?;
         let plan2str = format!("{plan2:?}");
         assert_eq!(expected_plan_str, &plan2str);
         Ok(())