You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/06/02 22:14:52 UTC
[arrow-datafusion] branch master updated: Support semi join (#470)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 33ff660 Support semi join (#470)
33ff660 is described below
commit 33ff660318bf60d0c9aa45ceba2c4c943bfe9438
Author: Daniël Heres <da...@gmail.com>
AuthorDate: Thu Jun 3 00:14:41 2021 +0200
Support semi join (#470)
* Support semi join
* Fmt
* Match on Semi
* Simplify
* Fmt
* Undo match
* Update datafusion/src/physical_plan/hash_join.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* Add item on the left for semi join
* Simplify pattern match
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
ballista/rust/core/proto/ballista.proto | 1 +
.../rust/core/src/serde/logical_plan/from_proto.rs | 1 +
.../rust/core/src/serde/logical_plan/to_proto.rs | 1 +
.../core/src/serde/physical_plan/from_proto.rs | 1 +
.../rust/core/src/serde/physical_plan/to_proto.rs | 1 +
datafusion/src/logical_plan/builder.rs | 4 +
datafusion/src/logical_plan/plan.rs | 4 +-
datafusion/src/optimizer/hash_build_probe_order.rs | 10 ++-
datafusion/src/physical_plan/hash_join.rs | 89 ++++++++++++++++++----
datafusion/src/physical_plan/hash_utils.rs | 3 +
datafusion/src/physical_plan/planner.rs | 1 +
11 files changed, 99 insertions(+), 17 deletions(-)
diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index da0c615..0387214 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -363,6 +363,7 @@ enum JoinType {
LEFT = 1;
RIGHT = 2;
FULL = 3;
+ SEMI = 4;
}
message JoinNode {
diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index 10c4670..4847126 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -265,6 +265,7 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
protobuf::JoinType::Left => JoinType::Left,
protobuf::JoinType::Right => JoinType::Right,
protobuf::JoinType::Full => JoinType::Full,
+ protobuf::JoinType::Semi => JoinType::Semi,
};
LogicalPlanBuilder::from(&convert_box_required!(join.left)?)
.join(
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index b630dfc..e1c0c5e 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -834,6 +834,7 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
JoinType::Left => protobuf::JoinType::Left,
JoinType::Right => protobuf::JoinType::Right,
JoinType::Full => protobuf::JoinType::Full,
+ JoinType::Semi => protobuf::JoinType::Semi,
};
let left_join_column = on.iter().map(|on| on.0.to_owned()).collect();
let right_join_column = on.iter().map(|on| on.1.to_owned()).collect();
diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index 2039def..7f98a83 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -379,6 +379,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
protobuf::JoinType::Left => JoinType::Left,
protobuf::JoinType::Right => JoinType::Right,
protobuf::JoinType::Full => JoinType::Full,
+ protobuf::JoinType::Semi => JoinType::Semi,
};
Ok(Arc::new(HashJoinExec::try_new(
left,
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 9571f3d..c409f94 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -133,6 +133,7 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn ExecutionPlan> {
JoinType::Left => protobuf::JoinType::Left,
JoinType::Right => protobuf::JoinType::Right,
JoinType::Full => protobuf::JoinType::Full,
+ JoinType::Semi => protobuf::JoinType::Semi,
};
Ok(protobuf::PhysicalPlanNode {
physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new(
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index c02555d..71de48c 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -410,6 +410,10 @@ fn build_join_schema(
// left then right
left_fields.chain(right_fields).cloned().collect()
}
+ JoinType::Semi => {
+ // Only use the left side for the schema
+ left.fields().clone()
+ }
JoinType::Right => {
// remove left-side join keys if they have the same names as the right-side
let duplicate_keys = &on
diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs
index 2d85abb..5cb94be 100644
--- a/datafusion/src/logical_plan/plan.rs
+++ b/datafusion/src/logical_plan/plan.rs
@@ -34,7 +34,7 @@ use std::{
};
/// Join type
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
/// Inner Join
Inner,
@@ -44,6 +44,8 @@ pub enum JoinType {
Right,
/// Full Join
Full,
+ /// Semi Join
+ Semi,
}
/// A LogicalPlan represents the different types of relational
diff --git a/datafusion/src/optimizer/hash_build_probe_order.rs b/datafusion/src/optimizer/hash_build_probe_order.rs
index 100ae4f..86d38ef 100644
--- a/datafusion/src/optimizer/hash_build_probe_order.rs
+++ b/datafusion/src/optimizer/hash_build_probe_order.rs
@@ -106,6 +106,13 @@ fn should_swap_join_order(left: &LogicalPlan, right: &LogicalPlan) -> bool {
}
}
+fn supports_swap(join_type: JoinType) -> bool {
+ match join_type {
+ JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => true,
+ JoinType::Semi => false,
+ }
+}
+
impl OptimizerRule for HashBuildProbeOrder {
fn name(&self) -> &str {
"hash_build_probe_order"
@@ -128,7 +135,7 @@ impl OptimizerRule for HashBuildProbeOrder {
} => {
let left = self.optimize(left, execution_props)?;
let right = self.optimize(right, execution_props)?;
- if should_swap_join_order(&left, &right) {
+ if should_swap_join_order(&left, &right) && supports_swap(*join_type) {
// Swap left and right, change join type and (equi-)join key order
Ok(LogicalPlan::Join {
left: Arc::new(right),
@@ -216,6 +223,7 @@ fn swap_join_type(join_type: JoinType) -> JoinType {
JoinType::Full => JoinType::Full,
JoinType::Left => JoinType::Right,
JoinType::Right => JoinType::Left,
+ _ => unreachable!(),
}
}
diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index 01551cd..6653b9a 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -184,7 +184,7 @@ impl HashJoinExec {
/// Calculates column indices and left/right placement on input / output schemas and jointype
fn column_indices_from_schema(&self) -> ArrowResult<Vec<ColumnIndex>> {
let (primary_is_left, primary_schema, secondary_schema) = match self.join_type {
- JoinType::Inner | JoinType::Left | JoinType::Full => {
+ JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Semi => {
(true, self.left.schema(), self.right.schema())
}
JoinType::Right => (false, self.right.schema(), self.left.schema()),
@@ -376,7 +376,7 @@ impl ExecutionPlan for HashJoinExec {
let column_indices = self.column_indices_from_schema()?;
let num_rows = left_data.1.num_rows();
let visited_left_side = match self.join_type {
- JoinType::Left | JoinType::Full => vec![false; num_rows],
+ JoinType::Left | JoinType::Full | JoinType::Semi => vec![false; num_rows],
JoinType::Inner | JoinType::Right => vec![],
};
Ok(Box::pin(HashJoinStream {
@@ -544,6 +544,13 @@ fn build_batch(
)
.unwrap();
+ if join_type == JoinType::Semi {
+ return Ok((
+ RecordBatch::new_empty(Arc::new(schema.clone())),
+ left_indices,
+ ));
+ }
+
build_batch_from_indices(
schema,
&left_data.1,
@@ -606,7 +613,7 @@ fn build_join_indexes(
let left = &left_data.0;
match join_type {
- JoinType::Inner => {
+ JoinType::Inner | JoinType::Semi => {
// Using a buffer builder to avoid slower normal builder
let mut left_indices = UInt64BufferBuilder::new(0);
let mut right_indices = UInt32BufferBuilder::new(0);
@@ -1108,23 +1115,35 @@ pub fn create_hashes<'a>(
Ok(hashes_buffer)
}
-// Produces a batch for left-side rows that are not marked as being visited during the whole join
-fn produce_unmatched(
+// Produces a batch for left-side rows that have/have not been matched during the whole join
+fn produce_from_matched(
visited_left_side: &[bool],
schema: &SchemaRef,
column_indices: &[ColumnIndex],
left_data: &JoinLeftData,
+ unmatched: bool,
) -> ArrowResult<RecordBatch> {
// Find indices which didn't match any right row (are false)
- let unmatched_indices: Vec<u64> = visited_left_side
- .iter()
- .enumerate()
- .filter(|&(_, &value)| !value)
- .map(|(index, _)| index as u64)
- .collect();
+ let indices = if unmatched {
+ UInt64Array::from_iter_values(
+ visited_left_side
+ .iter()
+ .enumerate()
+ .filter(|&(_, &value)| !value)
+ .map(|(index, _)| index as u64),
+ )
+ } else {
+ // produce those that did match
+ UInt64Array::from_iter_values(
+ visited_left_side
+ .iter()
+ .enumerate()
+ .filter(|&(_, &value)| value)
+ .map(|(index, _)| index as u64),
+ )
+ };
// generate batches by taking values from the left side and generating columns filled with null on the right side
- let indices = UInt64Array::from_iter_values(unmatched_indices);
let num_rows = indices.len();
let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
for (idx, column_index) in column_indices.iter().enumerate() {
@@ -1171,7 +1190,7 @@ impl Stream for HashJoinStream {
self.num_output_rows += batch.num_rows();
match self.join_type {
- JoinType::Left | JoinType::Full => {
+ JoinType::Left | JoinType::Full | JoinType::Semi => {
left_side.iter().flatten().for_each(|x| {
self.visited_left_side[x as usize] = true;
});
@@ -1185,12 +1204,15 @@ impl Stream for HashJoinStream {
let start = Instant::now();
// For the left join, produce rows for unmatched rows
match self.join_type {
- JoinType::Left | JoinType::Full if !self.is_exhausted => {
- let result = produce_unmatched(
+ JoinType::Left | JoinType::Full | JoinType::Semi
+ if !self.is_exhausted =>
+ {
+ let result = produce_from_matched(
&self.visited_left_side,
&self.schema,
&self.column_indices,
&self.left_data,
+ self.join_type != JoinType::Semi,
);
if let Ok(ref batch) = result {
self.num_input_batches += 1;
@@ -1207,6 +1229,7 @@ impl Stream for HashJoinStream {
}
JoinType::Left
| JoinType::Full
+ | JoinType::Semi
| JoinType::Inner
| JoinType::Right => {}
}
@@ -1667,6 +1690,42 @@ mod tests {
}
#[tokio::test]
+ async fn join_semi() -> Result<()> {
+ let left = build_table(
+ ("a1", &vec![1, 2, 2, 3]),
+ ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right
+ ("c1", &vec![7, 8, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30, 40]),
+ ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right
+ ("c2", &vec![70, 80, 90, 100]),
+ );
+ let on = &[("b1", "b1")];
+
+ let join = join(left, right, on, &JoinType::Semi)?;
+
+ let columns = columns(&join.schema());
+ assert_eq!(columns, vec!["a1", "b1", "c1"]);
+
+ let stream = join.execute(0).await?;
+ let batches = common::collect(stream).await?;
+
+ let expected = vec![
+ "+----+----+----+",
+ "| a1 | b1 | c1 |",
+ "+----+----+----+",
+ "| 1 | 4 | 7 |",
+ "| 2 | 5 | 8 |",
+ "| 2 | 5 | 8 |",
+ "+----+----+----+",
+ ];
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
+
+ #[tokio::test]
async fn join_right_one() -> Result<()> {
let left = build_table(
("a1", &vec![1, 2, 3]),
diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs
index 7e030af..110319e 100644
--- a/datafusion/src/physical_plan/hash_utils.rs
+++ b/datafusion/src/physical_plan/hash_utils.rs
@@ -32,6 +32,8 @@ pub enum JoinType {
Right,
/// Full Join
Full,
+ /// Semi Join
+ Semi,
}
/// The on clause of the join, as vector of (left, right) columns.
@@ -130,6 +132,7 @@ pub fn build_join_schema(
// left then right
left_fields.chain(right_fields).cloned().collect()
}
+ JoinType::Semi => left.fields().clone(),
};
Schema::new(fields)
}
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 7ddfaf8..4971a02 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -367,6 +367,7 @@ impl DefaultPhysicalPlanner {
JoinType::Left => hash_utils::JoinType::Left,
JoinType::Right => hash_utils::JoinType::Right,
JoinType::Full => hash_utils::JoinType::Full,
+ JoinType::Semi => hash_utils::JoinType::Semi,
};
if ctx_state.config.concurrency > 1 && ctx_state.config.repartition_joins
{