You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2023/01/27 01:38:47 UTC
[arrow-datafusion] branch master updated: Add null-equals-null join support (#5085)
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new ad3cc242a Add null-equals-null join support (#5085)
ad3cc242a is described below
commit ad3cc242ad55bd83a35ba3b4009b63a4bbc8c24a
Author: Nuttiiya Seekhao <37...@users.noreply.github.com>
AuthorDate: Thu Jan 26 20:38:40 2023 -0500
Add null-equals-null join support (#5085)
---
datafusion/substrait/Cargo.toml | 1 +
datafusion/substrait/src/consumer.rs | 45 +++++++++++++++++++++------------
datafusion/substrait/src/producer.rs | 21 +++++++++------
datafusion/substrait/tests/roundtrip.rs | 16 ++++++++++++
4 files changed, 59 insertions(+), 24 deletions(-)
diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml
index 5e9beef43..a007c8312 100644
--- a/datafusion/substrait/Cargo.toml
+++ b/datafusion/substrait/Cargo.toml
@@ -24,6 +24,7 @@ rust-version = "1.62"
[dependencies]
async-recursion = "1.0"
datafusion = { version = "17.0.0", path = "../core" }
+itertools = "0.10.5"
prost = "0.11"
prost-types = "0.11"
substrait = "0.4"
diff --git a/datafusion/substrait/src/consumer.rs b/datafusion/substrait/src/consumer.rs
index e7303e21c..52f1f052a 100644
--- a/datafusion/substrait/src/consumer.rs
+++ b/datafusion/substrait/src/consumer.rs
@@ -325,29 +325,42 @@ pub async fn from_substrait_rel(
)
.await?;
let predicates = split_conjunction(&on);
- let pairs = predicates
+ // TODO: collect only one null_eq_null
+ let join_exprs: Vec<(Column, Column, bool)> = predicates
.iter()
.map(|p| match p {
- Expr::BinaryExpr(BinaryExpr {
- left,
- op: Operator::Eq,
- right,
- }) => match (left.as_ref(), right.as_ref()) {
- (Expr::Column(l), Expr::Column(r)) => {
- Ok((l.flat_name(), r.flat_name()))
+ Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
+ match (left.as_ref(), right.as_ref()) {
+ (Expr::Column(l), Expr::Column(r)) => match op {
+ Operator::Eq => Ok((l.clone(), r.clone(), false)),
+ Operator::IsNotDistinctFrom => {
+ Ok((l.clone(), r.clone(), true))
+ }
+ _ => Err(DataFusionError::Internal(
+ "invalid join condition op".to_string(),
+ )),
+ },
+ _ => Err(DataFusionError::Internal(
+ "invalid join condition expresssion".to_string(),
+ )),
}
- _ => Err(DataFusionError::Internal(
- "invalid join condition".to_string(),
- )),
- },
+ }
_ => Err(DataFusionError::Internal(
- "invalid join condition".to_string(),
+ "Non-binary expression is not supported in join condition"
+ .to_string(),
)),
})
.collect::<Result<Vec<_>>>()?;
- let (left_cols, right_cols): (Vec<_>, Vec<_>) = pairs.iter().cloned().unzip();
- left.join(right.build()?, join_type, (left_cols, right_cols), None)?
- .build()
+ let (left_cols, right_cols, null_eq_nulls): (Vec<_>, Vec<_>, Vec<_>) =
+ itertools::multiunzip(join_exprs);
+ left.join_detailed(
+ right.build()?,
+ join_type,
+ (left_cols, right_cols),
+ None,
+ null_eq_nulls[0],
+ )?
+ .build()
}
Some(RelType::Read(read)) => match &read.as_ref().read_type {
Some(ReadType::NamedTable(nt)) => {
diff --git a/datafusion/substrait/src/producer.rs b/datafusion/substrait/src/producer.rs
index 3a9c4923b..ef8a52d9d 100644
--- a/datafusion/substrait/src/producer.rs
+++ b/datafusion/substrait/src/producer.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use std::collections::HashMap;
+use std::{collections::HashMap, sync::Arc};
use datafusion::{
error::{DataFusionError, Result},
@@ -246,11 +246,6 @@ pub fn to_substrait_rel(
let right = to_substrait_rel(join.right.as_ref(), extension_info)?;
let join_type = to_substrait_jointype(join.join_type);
// we only support basic joins so return an error for anything not yet supported
- if join.null_equals_null {
- return Err(DataFusionError::NotImplemented(
- "join null_equals_null".to_string(),
- ));
- }
if join.filter.is_some() {
return Err(DataFusionError::NotImplemented("join filter".to_string()));
}
@@ -264,11 +259,20 @@ pub fn to_substrait_rel(
}
// map the left and right columns to binary expressions in the form `l = r`
// build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b`
+ let eq_op = if join.null_equals_null {
+ Operator::IsNotDistinctFrom
+ } else {
+ Operator::Eq
+ };
let join_expression = join
.on
.iter()
- .map(|(l, r)| binary_expr(l.clone(), Operator::Eq, r.clone()))
+ .map(|(l, r)| binary_expr(l.clone(), eq_op, r.clone()))
.reduce(|acc: Expr, expr: Expr| acc.and(expr));
+ // join schema from left and right to maintain all necessary columns from inputs
+ // note that we cannot simply use join.schema here since we discard some input columns
+ // when performing semi and anti joins
+ let join_schema = join.left.schema().join(join.right.schema());
if let Some(e) = join_expression {
Ok(Box::new(Rel {
rel_type: Some(RelType::Join(Box::new(JoinRel {
@@ -278,7 +282,7 @@ pub fn to_substrait_rel(
r#type: join_type as i32,
expression: Some(Box::new(to_substrait_rex(
&e,
- &join.schema,
+ &Arc::new(join_schema?),
extension_info,
)?)),
post_join_filter: None,
@@ -579,6 +583,7 @@ pub fn to_substrait_rex(
ScalarValue::Int16(Some(n)) => Some(LiteralType::I16(*n as i32)),
ScalarValue::Int32(Some(n)) => Some(LiteralType::I32(*n)),
ScalarValue::Int64(Some(n)) => Some(LiteralType::I64(*n)),
+ ScalarValue::UInt8(Some(n)) => Some(LiteralType::I16(*n as i32)), // Substrait currently does not support unsigned integer
ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)),
ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
diff --git a/datafusion/substrait/tests/roundtrip.rs b/datafusion/substrait/tests/roundtrip.rs
index 0ddde1754..f9bb53bd0 100644
--- a/datafusion/substrait/tests/roundtrip.rs
+++ b/datafusion/substrait/tests/roundtrip.rs
@@ -219,12 +219,28 @@ mod tests {
.await
}
+ #[tokio::test]
+ async fn simple_intersect() -> Result<()> {
+ assert_expected_plan(
+ "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);",
+ "Projection: COUNT(Int16(1))\
+ \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int16(1))]]\
+ \n LeftSemi Join: data.a = data2.a\
+ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\
+ \n TableScan: data projection=[a]\
+ \n Projection: data2.a\
+ \n TableScan: data2 projection=[a]",
+ )
+ .await
+ }
+
async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> {
let mut ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
let proto = to_substrait_plan(&plan)?;
let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
+ let plan2 = ctx.state().optimize(&plan2)?;
let plan2str = format!("{plan2:?}");
assert_eq!(expected_plan_str, &plan2str);
Ok(())