You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/11/06 11:05:09 UTC
[arrow-datafusion] branch master updated: fix comments (#1135)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 2b24b89 fix comments (#1135)
2b24b89 is described below
commit 2b24b892058d2dc87a29e931c1e81d621a60942e
Author: Carlos <wx...@gmail.com>
AuthorDate: Sat Nov 6 19:05:00 2021 +0800
fix comments (#1135)
---
ballista/rust/core/proto/ballista.proto | 2 +
.../rust/core/src/serde/logical_plan/to_proto.rs | 2 +
.../core/src/serde/physical_plan/from_proto.rs | 1 +
ballista/rust/core/src/serde/physical_plan/mod.rs | 1 +
.../rust/core/src/serde/physical_plan/to_proto.rs | 1 +
datafusion/src/logical_plan/builder.rs | 14 ++
datafusion/src/logical_plan/plan.rs | 2 +
datafusion/src/optimizer/projection_push_down.rs | 2 +
datafusion/src/optimizer/utils.rs | 2 +
.../physical_optimizer/hash_build_probe_order.rs | 3 +
datafusion/src/physical_plan/hash_join.rs | 212 ++++++++++++++++-----
datafusion/src/physical_plan/planner.rs | 3 +
datafusion/src/sql/planner.rs | 26 ++-
datafusion/tests/sql.rs | 65 +++++++
14 files changed, 282 insertions(+), 54 deletions(-)
diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 62b3185..1815811 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -435,6 +435,7 @@ message JoinNode {
JoinConstraint join_constraint = 4;
repeated Column left_join_column = 5;
repeated Column right_join_column = 6;
+ bool null_equals_null = 7;
}
message CrossJoinNode {
@@ -649,6 +650,7 @@ message HashJoinExecNode {
repeated JoinOn on = 3;
JoinType join_type = 4;
PartitionMode partition_mode = 6;
+ bool null_equals_null = 7;
}
message CrossJoinExecNode {
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index 1d1d48e..80af698 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -850,6 +850,7 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
on,
join_type,
join_constraint,
+ null_equals_null,
..
} => {
let left: protobuf::LogicalPlanNode = left.as_ref().try_into()?;
@@ -868,6 +869,7 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
join_constraint: join_constraint.into(),
left_join_column,
right_join_column,
+ null_equals_null: *null_equals_null,
},
))),
})
diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index 99d2de0..3c05957 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -365,6 +365,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
on,
&join_type.into(),
partition_mode,
+ &hashjoin.null_equals_null,
)?))
}
PhysicalPlanType::CrossJoin(crossjoin) => {
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs
index 4bf013a..aca8f64 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -107,6 +107,7 @@ mod roundtrip_tests {
on.clone(),
join_type,
*partition_mode,
+ &false,
)?))?;
}
}
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index afbb02a..41484db 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -162,6 +162,7 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn ExecutionPlan> {
on,
join_type: join_type.into(),
partition_mode: partition_mode.into(),
+ null_equals_null: *exec.null_equals_null(),
},
))),
})
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index a9d814f..ac8e0c3 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -493,6 +493,18 @@ impl LogicalPlanBuilder {
join_type: JoinType,
join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>),
) -> Result<Self> {
+ self.join_detailed(right, join_type, join_keys, false)
+ }
+
+ /// Apply a join with an `ON` constraint and the specified null-equality semantics.
+ /// If `null_equals_null` is true, null join keys match each other (null == null); otherwise null != null.
+ pub fn join_detailed(
+ &self,
+ right: &LogicalPlan,
+ join_type: JoinType,
+ join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>),
+ null_equals_null: bool,
+ ) -> Result<Self> {
if join_keys.0.len() != join_keys.1.len() {
return Err(DataFusionError::Plan(
"left_keys and right_keys were not the same length".to_string(),
@@ -580,6 +592,7 @@ impl LogicalPlanBuilder {
join_type,
join_constraint: JoinConstraint::On,
schema: DFSchemaRef::new(join_schema),
+ null_equals_null,
}))
}
@@ -611,6 +624,7 @@ impl LogicalPlanBuilder {
join_type,
join_constraint: JoinConstraint::Using,
schema: DFSchemaRef::new(join_schema),
+ null_equals_null: false,
}))
}
diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs
index 13921d5..f53d8e6 100644
--- a/datafusion/src/logical_plan/plan.rs
+++ b/datafusion/src/logical_plan/plan.rs
@@ -135,6 +135,8 @@ pub enum LogicalPlan {
join_constraint: JoinConstraint,
/// The output schema, containing fields from the left and right inputs
schema: DFSchemaRef,
+ /// If null_equals_null is true, null == null else null != null
+ null_equals_null: bool,
},
/// Apply Cross Join to two logical plans
CrossJoin {
diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs
index 4fabc4f..78ea278 100644
--- a/datafusion/src/optimizer/projection_push_down.rs
+++ b/datafusion/src/optimizer/projection_push_down.rs
@@ -195,6 +195,7 @@ fn optimize_plan(
on,
join_type,
join_constraint,
+ null_equals_null,
..
} => {
for (l, r) in on {
@@ -231,6 +232,7 @@ fn optimize_plan(
join_constraint: *join_constraint,
on: on.clone(),
schema: DFSchemaRef::new(schema),
+ null_equals_null: *null_equals_null,
})
}
LogicalPlan::Window {
diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs
index f36330e..8ce0d3a 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -198,6 +198,7 @@ pub fn from_plan(
join_type,
join_constraint,
on,
+ null_equals_null,
..
} => {
let schema =
@@ -209,6 +210,7 @@ pub fn from_plan(
join_constraint: *join_constraint,
on: on.clone(),
schema: DFSchemaRef::new(schema),
+ null_equals_null: *null_equals_null,
})
}
LogicalPlan::CrossJoin { .. } => {
diff --git a/datafusion/src/physical_optimizer/hash_build_probe_order.rs b/datafusion/src/physical_optimizer/hash_build_probe_order.rs
index 0b87ceb..0d1c39f 100644
--- a/datafusion/src/physical_optimizer/hash_build_probe_order.rs
+++ b/datafusion/src/physical_optimizer/hash_build_probe_order.rs
@@ -123,6 +123,7 @@ impl PhysicalOptimizerRule for HashBuildProbeOrder {
.collect(),
&swap_join_type(*hash_join.join_type()),
*hash_join.partition_mode(),
+ hash_join.null_equals_null(),
)?;
let proj = ProjectionExec::try_new(
swap_reverting_projection(&*left.schema(), &*right.schema()),
@@ -195,6 +196,7 @@ mod tests {
)],
&JoinType::Left,
PartitionMode::CollectLeft,
+ &false,
)
.unwrap();
@@ -238,6 +240,7 @@ mod tests {
)],
&JoinType::Left,
PartitionMode::CollectLeft,
+ &false,
)
.unwrap();
diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index 2ed0faa..727d1c6 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -120,6 +120,8 @@ pub struct HashJoinExec {
metrics: ExecutionPlanMetricsSet,
/// Information of index and left / right placement of columns
column_indices: Vec<ColumnIndex>,
+ /// If null_equals_null is true, null == null else null != null
+ null_equals_null: bool,
}
/// Metrics for HashJoinExec
@@ -180,6 +182,7 @@ impl HashJoinExec {
on: JoinOn,
join_type: &JoinType,
partition_mode: PartitionMode,
+ null_equals_null: &bool,
) -> Result<Self> {
let left_schema = left.schema();
let right_schema = right.schema();
@@ -201,6 +204,7 @@ impl HashJoinExec {
mode: partition_mode,
metrics: ExecutionPlanMetricsSet::new(),
column_indices,
+ null_equals_null: *null_equals_null,
})
}
@@ -228,6 +232,11 @@ impl HashJoinExec {
pub fn partition_mode(&self) -> &PartitionMode {
&self.mode
}
+
+ /// Returns whether null join keys are treated as equal (true: null == null, false: null != null)
+ pub fn null_equals_null(&self) -> &bool {
+ &self.null_equals_null
+ }
}
#[async_trait]
@@ -255,6 +264,7 @@ impl ExecutionPlan for HashJoinExec {
self.on.clone(),
&self.join_type,
self.mode,
+ &self.null_equals_null,
)?)),
_ => Err(DataFusionError::Internal(
"HashJoinExec wrong number of children".to_string(),
@@ -406,6 +416,7 @@ impl ExecutionPlan for HashJoinExec {
self.random_state.clone(),
visited_left_side,
HashJoinMetrics::new(partition, &self.metrics),
+ self.null_equals_null,
)))
}
@@ -499,6 +510,8 @@ struct HashJoinStream {
join_metrics: HashJoinMetrics,
/// Information of index and left / right placement of columns
column_indices: Vec<ColumnIndex>,
+ /// If null_equals_null is true, null == null else null != null
+ null_equals_null: bool,
}
#[allow(clippy::too_many_arguments)]
@@ -514,6 +527,7 @@ impl HashJoinStream {
random_state: RandomState,
visited_left_side: Vec<bool>,
join_metrics: HashJoinMetrics,
+ null_equals_null: bool,
) -> Self {
HashJoinStream {
schema,
@@ -527,6 +541,7 @@ impl HashJoinStream {
visited_left_side,
is_exhausted: false,
join_metrics,
+ null_equals_null,
}
}
}
@@ -581,10 +596,18 @@ fn build_batch(
schema: &Schema,
column_indices: &[ColumnIndex],
random_state: &RandomState,
+ null_equals_null: &bool,
) -> ArrowResult<(RecordBatch, UInt64Array)> {
- let (left_indices, right_indices) =
- build_join_indexes(left_data, batch, join_type, on_left, on_right, random_state)
- .unwrap();
+ let (left_indices, right_indices) = build_join_indexes(
+ left_data,
+ batch,
+ join_type,
+ on_left,
+ on_right,
+ random_state,
+ null_equals_null,
+ )
+ .unwrap();
if matches!(join_type, JoinType::Semi | JoinType::Anti) {
return Ok((
@@ -637,6 +660,7 @@ fn build_join_indexes(
left_on: &[Column],
right_on: &[Column],
random_state: &RandomState,
+ null_equals_null: &bool,
) -> Result<(UInt64Array, UInt32Array)> {
let keys_values = right_on
.iter()
@@ -668,7 +692,13 @@ fn build_join_indexes(
{
for &i in indices {
// Check hash collisions
- if equal_rows(i as usize, row, &left_join_values, &keys_values)? {
+ if equal_rows(
+ i as usize,
+ row,
+ &left_join_values,
+ &keys_values,
+ *null_equals_null,
+ )? {
left_indices.append(i);
right_indices.append(row as u32);
}
@@ -702,7 +732,13 @@ fn build_join_indexes(
{
for &i in indices {
// Collision check
- if equal_rows(i as usize, row, &left_join_values, &keys_values)? {
+ if equal_rows(
+ i as usize,
+ row,
+ &left_join_values,
+ &keys_values,
+ *null_equals_null,
+ )? {
left_indices.append_value(i)?;
right_indices.append_value(row as u32)?;
}
@@ -725,6 +761,7 @@ fn build_join_indexes(
row,
&left_join_values,
&keys_values,
+ *null_equals_null,
)? {
left_indices.append_value(i)?;
right_indices.append_value(row as u32)?;
@@ -751,12 +788,13 @@ fn build_join_indexes(
}
macro_rules! equal_rows_elem {
- ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => {{
+ ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{
let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap();
let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap();
match (left_array.is_null($left), right_array.is_null($right)) {
(false, false) => left_array.value($left) == right_array.value($right),
+ (true, true) => $null_equals_null,
_ => false,
}
}};
@@ -768,6 +806,7 @@ fn equal_rows(
right: usize,
left_arrays: &[ArrayRef],
right_arrays: &[ArrayRef],
+ null_equals_null: bool,
) -> Result<bool> {
let mut err = None;
let res = left_arrays
@@ -775,33 +814,87 @@ fn equal_rows(
.zip(right_arrays)
.all(|(l, r)| match l.data_type() {
DataType::Null => true,
- DataType::Boolean => equal_rows_elem!(BooleanArray, l, r, left, right),
- DataType::Int8 => equal_rows_elem!(Int8Array, l, r, left, right),
- DataType::Int16 => equal_rows_elem!(Int16Array, l, r, left, right),
- DataType::Int32 => equal_rows_elem!(Int32Array, l, r, left, right),
- DataType::Int64 => equal_rows_elem!(Int64Array, l, r, left, right),
- DataType::UInt8 => equal_rows_elem!(UInt8Array, l, r, left, right),
- DataType::UInt16 => equal_rows_elem!(UInt16Array, l, r, left, right),
- DataType::UInt32 => equal_rows_elem!(UInt32Array, l, r, left, right),
- DataType::UInt64 => equal_rows_elem!(UInt64Array, l, r, left, right),
- DataType::Float32 => equal_rows_elem!(Float32Array, l, r, left, right),
- DataType::Float64 => equal_rows_elem!(Float64Array, l, r, left, right),
+ DataType::Boolean => {
+ equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null)
+ }
+ DataType::Int8 => {
+ equal_rows_elem!(Int8Array, l, r, left, right, null_equals_null)
+ }
+ DataType::Int16 => {
+ equal_rows_elem!(Int16Array, l, r, left, right, null_equals_null)
+ }
+ DataType::Int32 => {
+ equal_rows_elem!(Int32Array, l, r, left, right, null_equals_null)
+ }
+ DataType::Int64 => {
+ equal_rows_elem!(Int64Array, l, r, left, right, null_equals_null)
+ }
+ DataType::UInt8 => {
+ equal_rows_elem!(UInt8Array, l, r, left, right, null_equals_null)
+ }
+ DataType::UInt16 => {
+ equal_rows_elem!(UInt16Array, l, r, left, right, null_equals_null)
+ }
+ DataType::UInt32 => {
+ equal_rows_elem!(UInt32Array, l, r, left, right, null_equals_null)
+ }
+ DataType::UInt64 => {
+ equal_rows_elem!(UInt64Array, l, r, left, right, null_equals_null)
+ }
+ DataType::Float32 => {
+ equal_rows_elem!(Float32Array, l, r, left, right, null_equals_null)
+ }
+ DataType::Float64 => {
+ equal_rows_elem!(Float64Array, l, r, left, right, null_equals_null)
+ }
DataType::Timestamp(time_unit, None) => match time_unit {
TimeUnit::Second => {
- equal_rows_elem!(TimestampSecondArray, l, r, left, right)
+ equal_rows_elem!(
+ TimestampSecondArray,
+ l,
+ r,
+ left,
+ right,
+ null_equals_null
+ )
}
TimeUnit::Millisecond => {
- equal_rows_elem!(TimestampMillisecondArray, l, r, left, right)
+ equal_rows_elem!(
+ TimestampMillisecondArray,
+ l,
+ r,
+ left,
+ right,
+ null_equals_null
+ )
}
TimeUnit::Microsecond => {
- equal_rows_elem!(TimestampMicrosecondArray, l, r, left, right)
+ equal_rows_elem!(
+ TimestampMicrosecondArray,
+ l,
+ r,
+ left,
+ right,
+ null_equals_null
+ )
}
TimeUnit::Nanosecond => {
- equal_rows_elem!(TimestampNanosecondArray, l, r, left, right)
+ equal_rows_elem!(
+ TimestampNanosecondArray,
+ l,
+ r,
+ left,
+ right,
+ null_equals_null
+ )
}
},
- DataType::Utf8 => equal_rows_elem!(StringArray, l, r, left, right),
- DataType::LargeUtf8 => equal_rows_elem!(LargeStringArray, l, r, left, right),
+ DataType::Utf8 => {
+ equal_rows_elem!(StringArray, l, r, left, right, null_equals_null)
+ }
+ DataType::LargeUtf8 => {
+ equal_rows_elem!(LargeStringArray, l, r, left, right, null_equals_null)
+ }
_ => {
// This is internal because we should have caught this before.
err = Some(Err(DataFusionError::Internal(
@@ -883,6 +976,7 @@ impl Stream for HashJoinStream {
&self.schema,
&self.column_indices,
&self.random_state,
+ &self.null_equals_null,
);
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
@@ -976,8 +1070,16 @@ mod tests {
right: Arc<dyn ExecutionPlan>,
on: JoinOn,
join_type: &JoinType,
+ null_equals_null: bool,
) -> Result<HashJoinExec> {
- HashJoinExec::try_new(left, right, on, join_type, PartitionMode::CollectLeft)
+ HashJoinExec::try_new(
+ left,
+ right,
+ on,
+ join_type,
+ PartitionMode::CollectLeft,
+ &null_equals_null,
+ )
}
async fn join_collect(
@@ -985,8 +1087,9 @@ mod tests {
right: Arc<dyn ExecutionPlan>,
on: JoinOn,
join_type: &JoinType,
+ null_equals_null: bool,
) -> Result<(Vec<String>, Vec<RecordBatch>)> {
- let join = join(left, right, on, join_type)?;
+ let join = join(left, right, on, join_type, null_equals_null)?;
let columns = columns(&join.schema());
let stream = join.execute(0).await?;
@@ -1000,6 +1103,7 @@ mod tests {
right: Arc<dyn ExecutionPlan>,
on: JoinOn,
join_type: &JoinType,
+ null_equals_null: bool,
) -> Result<(Vec<String>, Vec<RecordBatch>)> {
let partition_count = 4;
@@ -1025,6 +1129,7 @@ mod tests {
on,
join_type,
PartitionMode::Partitioned,
+ &null_equals_null,
)?;
let columns = columns(&join.schema());
@@ -1062,9 +1167,14 @@ mod tests {
Column::new_with_schema("b1", &right.schema())?,
)];
- let (columns, batches) =
- join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Inner)
- .await?;
+ let (columns, batches) = join_collect(
+ left.clone(),
+ right.clone(),
+ on.clone(),
+ &JoinType::Inner,
+ false,
+ )
+ .await?;
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
@@ -1104,6 +1214,7 @@ mod tests {
right.clone(),
on.clone(),
&JoinType::Inner,
+ false,
)
.await?;
@@ -1140,7 +1251,8 @@ mod tests {
Column::new_with_schema("b2", &right.schema())?,
)];
- let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?;
+ let (columns, batches) =
+ join_collect(left, right, on, &JoinType::Inner, false).await?;
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
@@ -1182,7 +1294,8 @@ mod tests {
),
];
- let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?;
+ let (columns, batches) =
+ join_collect(left, right, on, &JoinType::Inner, false).await?;
assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]);
@@ -1234,7 +1347,8 @@ mod tests {
),
];
- let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?;
+ let (columns, batches) =
+ join_collect(left, right, on, &JoinType::Inner, false).await?;
assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]);
@@ -1281,7 +1395,7 @@ mod tests {
Column::new_with_schema("b1", &right.schema())?,
)];
- let join = join(left, right, on, &JoinType::Inner)?;
+ let join = join(left, right, on, &JoinType::Inner, false)?;
let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
@@ -1347,7 +1461,7 @@ mod tests {
Column::new_with_schema("b1", &right.schema()).unwrap(),
)];
- let join = join(left, right, on, &JoinType::Left).unwrap();
+ let join = join(left, right, on, &JoinType::Left, false).unwrap();
let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
@@ -1388,7 +1502,7 @@ mod tests {
Column::new_with_schema("b2", &right.schema()).unwrap(),
)];
- let join = join(left, right, on, &JoinType::Full).unwrap();
+ let join = join(left, right, on, &JoinType::Full, false).unwrap();
let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
@@ -1427,7 +1541,7 @@ mod tests {
)];
let schema = right.schema();
let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap());
- let join = join(left, right, on, &JoinType::Left).unwrap();
+ let join = join(left, right, on, &JoinType::Left, false).unwrap();
let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
@@ -1462,7 +1576,7 @@ mod tests {
)];
let schema = right.schema();
let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap());
- let join = join(left, right, on, &JoinType::Full).unwrap();
+ let join = join(left, right, on, &JoinType::Full, false).unwrap();
let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
@@ -1500,9 +1614,14 @@ mod tests {
Column::new_with_schema("b1", &right.schema())?,
)];
- let (columns, batches) =
- join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Left)
- .await?;
+ let (columns, batches) = join_collect(
+ left.clone(),
+ right.clone(),
+ on.clone(),
+ &JoinType::Left,
+ false,
+ )
+ .await?;
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
let expected = vec![
@@ -1541,6 +1660,7 @@ mod tests {
right.clone(),
on.clone(),
&JoinType::Left,
+ false,
)
.await?;
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
@@ -1576,7 +1696,7 @@ mod tests {
Column::new_with_schema("b1", &right.schema())?,
)];
- let join = join(left, right, on, &JoinType::Semi)?;
+ let join = join(left, right, on, &JoinType::Semi, false)?;
let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1"]);
@@ -1615,7 +1735,7 @@ mod tests {
Column::new_with_schema("b1", &right.schema())?,
)];
- let join = join(left, right, on, &JoinType::Anti)?;
+ let join = join(left, right, on, &JoinType::Anti, false)?;
let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1"]);
@@ -1652,7 +1772,8 @@ mod tests {
Column::new_with_schema("b1", &right.schema())?,
)];
- let (columns, batches) = join_collect(left, right, on, &JoinType::Right).await?;
+ let (columns, batches) =
+ join_collect(left, right, on, &JoinType::Right, false).await?;
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
@@ -1689,7 +1810,7 @@ mod tests {
)];
let (columns, batches) =
- partitioned_join_collect(left, right, on, &JoinType::Right).await?;
+ partitioned_join_collect(left, right, on, &JoinType::Right, false).await?;
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
@@ -1725,7 +1846,7 @@ mod tests {
Column::new_with_schema("b2", &right.schema()).unwrap(),
)];
- let join = join(left, right, on, &JoinType::Full)?;
+ let join = join(left, right, on, &JoinType::Full, false)?;
let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
@@ -1780,6 +1901,7 @@ mod tests {
&[Column::new("a", 0)],
&[Column::new("a", 0)],
&random_state,
+ &false,
)?;
let mut left_ids = UInt64Builder::new(0);
@@ -1815,7 +1937,7 @@ mod tests {
Column::new_with_schema("b", &right.schema()).unwrap(),
)];
- let join = join(left, right, on, &JoinType::Inner)?;
+ let join = join(left, right, on, &JoinType::Inner, false)?;
let columns = columns(&join.schema());
assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]);
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 87fc166..12e7401 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -713,6 +713,7 @@ impl DefaultPhysicalPlanner {
right,
on: keys,
join_type,
+ null_equals_null,
..
} => {
let left_df_schema = left.schema();
@@ -761,6 +762,7 @@ impl DefaultPhysicalPlanner {
join_on,
join_type,
PartitionMode::Partitioned,
+ null_equals_null,
)?))
} else {
Ok(Arc::new(HashJoinExec::try_new(
@@ -769,6 +771,7 @@ impl DefaultPhysicalPlanner {
join_on,
join_type,
PartitionMode::CollectLeft,
+ null_equals_null,
)?))
}
}
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index 1653cb5..de09632 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -192,23 +192,31 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
left,
right,
all,
- } => match (op, all) {
+ } => {
+ let left_plan = self.set_expr_to_plan(left.as_ref(), None, ctes)?;
+ let right_plan = self.set_expr_to_plan(right.as_ref(), None, ctes)?;
+ match (op, all) {
(SetOperator::Union, true) => {
- let left_plan = self.set_expr_to_plan(left.as_ref(), None, ctes)?;
- let right_plan = self.set_expr_to_plan(right.as_ref(), None, ctes)?;
union_with_alias(left_plan, right_plan, alias)
}
(SetOperator::Union, false) => {
- let left_plan = self.set_expr_to_plan(left.as_ref(), None, ctes)?;
- let right_plan = self.set_expr_to_plan(right.as_ref(), None, ctes)?;
let union_plan = union_with_alias(left_plan, right_plan, alias)?;
LogicalPlanBuilder::from(union_plan).distinct()?.build()
}
+ (SetOperator::Intersect, true) => {
+ let join_keys = left_plan.schema().fields().iter().zip(right_plan.schema().fields().iter()).map(|(left_field, right_field)| ((Column::from_name(left_field.name())), (Column::from_name(right_field.name())))).unzip();
+ LogicalPlanBuilder::from(left_plan).join_detailed(&right_plan, JoinType::Semi, join_keys, true)?.build()
+ }
+ (SetOperator::Intersect, false) => {
+ let join_keys = left_plan.schema().fields().iter().zip(right_plan.schema().fields().iter()).map(|(left_field, right_field)| ((Column::from_name(left_field.name())), (Column::from_name(right_field.name())))).unzip();
+ LogicalPlanBuilder::from(left_plan).distinct()?.join_detailed(&right_plan, JoinType::Semi, join_keys, true)?.build()
+ }
_ => Err(DataFusionError::NotImplemented(format!(
- "Only UNION ALL and UNION [DISTINCT] are supported, found {}",
+ "Only UNION ALL and UNION [DISTINCT] and INTERSECT and INTERSECT [DISTINCT] are supported, found {}",
op
))),
- },
+ }
+ }
_ => Err(DataFusionError::NotImplemented(format!(
"Query {} not implemented yet",
set_expr
@@ -3543,11 +3551,11 @@ mod tests {
}
#[test]
- fn only_union_all_supported() {
+ fn except_not_supported() {
let sql = "SELECT order_id from orders EXCEPT SELECT order_id FROM orders";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
- "NotImplemented(\"Only UNION ALL and UNION [DISTINCT] are supported, found EXCEPT\")",
+ "NotImplemented(\"Only UNION ALL and UNION [DISTINCT] and INTERSECT and INTERSECT [DISTINCT] are supported, found EXCEPT\")",
format!("{:?}", err)
);
}
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 3af5501..996908a 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -5580,3 +5580,68 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> {
assert_eq!(expected, actual);
Ok(())
}
+
+#[tokio::test]
+async fn intersect_with_null_not_equal() {
+ let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1
+ INTERSECT SELECT * FROM (SELECT null AS id1, 2 AS id2) t2";
+
+ let expected: &[&[&str]] = &[];
+
+ let mut ctx = create_join_context_qualified().unwrap();
+ let actual = execute(&mut ctx, sql).await;
+
+ assert_eq!(expected, actual);
+}
+
+#[tokio::test]
+async fn intersect_with_null_equal() {
+ let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1
+ INTERSECT SELECT * FROM (SELECT null AS id1, 1 AS id2) t2";
+
+ let expected: Vec<Vec<String>> = vec![vec!["NULL".to_string(), "1".to_string()]];
+
+ let mut ctx = create_join_context_qualified().unwrap();
+ let actual = execute(&mut ctx, sql).await;
+
+ assert_eq!(expected, actual);
+}
+
+#[tokio::test]
+async fn test_intersect_all() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ register_alltypes_parquet(&mut ctx).await;
+ // execute the query
+ let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT ALL SELECT int_col, double_col FROM alltypes_plain LIMIT 4";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ let expected = vec![
+ "+---------+------------+",
+ "| int_col | double_col |",
+ "+---------+------------+",
+ "| 1 | 10.1 |",
+ "| 1 | 10.1 |",
+ "| 1 | 10.1 |",
+ "| 1 | 10.1 |",
+ "+---------+------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_intersect_distinct() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ register_alltypes_parquet(&mut ctx).await;
+ // execute the query
+ let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT SELECT int_col, double_col FROM alltypes_plain";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ let expected = vec![
+ "+---------+------------+",
+ "| int_col | double_col |",
+ "+---------+------------+",
+ "| 1 | 10.1 |",
+ "+---------+------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}