You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by dh...@apache.org on 2021/09/20 19:48:49 UTC
[arrow-datafusion] branch master updated: fix: allow duplicate
field names in table join, fix output with duplicated names (#1023)
This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 65483d3 fix: allow duplicate field names in table join, fix output with duplicated names (#1023)
65483d3 is described below
commit 65483d32f6ee86766bb74988659bb51142a4edff
Author: QP Hou <qp...@scribd.com>
AuthorDate: Mon Sep 20 12:48:46 2021 -0700
fix: allow duplicate field names in table join, fix output with duplicated names (#1023)
* fix: allow duplicate field names in table join
* move join related code into join_utils.rs
---
datafusion/src/physical_plan/cross_join.rs | 2 +-
datafusion/src/physical_plan/hash_join.rs | 142 ++++++++++---------
datafusion/src/physical_plan/hash_utils.rs | 144 +-------------------
datafusion/src/physical_plan/join_utils.rs | 212 +++++++++++++++++++++++++++++
datafusion/src/physical_plan/mod.rs | 1 +
datafusion/src/physical_plan/planner.rs | 4 +-
datafusion/tests/sql.rs | 63 +++++++++
7 files changed, 356 insertions(+), 212 deletions(-)
diff --git a/datafusion/src/physical_plan/cross_join.rs b/datafusion/src/physical_plan/cross_join.rs
index 1575958..a70d777 100644
--- a/datafusion/src/physical_plan/cross_join.rs
+++ b/datafusion/src/physical_plan/cross_join.rs
@@ -28,7 +28,7 @@ use arrow::record_batch::RecordBatch;
use futures::{Stream, TryStreamExt};
use super::{
- coalesce_partitions::CoalescePartitionsExec, hash_utils::check_join_is_valid,
+ coalesce_partitions::CoalescePartitionsExec, join_utils::check_join_is_valid,
ColumnStatistics, Statistics,
};
use crate::{
diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index f2ce88f..d7aba9e 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -52,7 +52,7 @@ use hashbrown::raw::RawTable;
use super::{
coalesce_partitions::CoalescePartitionsExec,
- hash_utils::{build_join_schema, check_join_is_valid, JoinOn},
+ join_utils::{build_join_schema, check_join_is_valid, ColumnIndex, JoinOn, JoinSide},
};
use super::{
expressions::Column,
@@ -115,6 +115,8 @@ pub struct HashJoinExec {
mode: PartitionMode,
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
+ /// Information of index and left / right placement of columns
+ column_indices: Vec<ColumnIndex>,
}
/// Metrics for HashJoinExec
@@ -165,14 +167,6 @@ pub enum PartitionMode {
CollectLeft,
}
-/// Information about the index and placement (left or right) of the columns
-struct ColumnIndex {
- /// Index of the column
- index: usize,
- /// Whether the column is at the left or right side
- is_left: bool,
-}
-
impl HashJoinExec {
/// Tries to create a new [HashJoinExec].
/// # Error
@@ -188,7 +182,8 @@ impl HashJoinExec {
let right_schema = right.schema();
check_join_is_valid(&left_schema, &right_schema, &on)?;
- let schema = Arc::new(build_join_schema(&left_schema, &right_schema, join_type));
+ let (schema, column_indices) =
+ build_join_schema(&left_schema, &right_schema, join_type);
let random_state = RandomState::with_seeds(0, 0, 0, 0);
@@ -197,11 +192,12 @@ impl HashJoinExec {
right,
on,
join_type: *join_type,
- schema,
+ schema: Arc::new(schema),
build_side: Arc::new(Mutex::new(None)),
random_state,
mode: partition_mode,
metrics: ExecutionPlanMetricsSet::new(),
+ column_indices,
})
}
@@ -229,38 +225,6 @@ impl HashJoinExec {
pub fn partition_mode(&self) -> &PartitionMode {
&self.mode
}
-
- /// Calculates column indices and left/right placement on input / output schemas and jointype
- fn column_indices_from_schema(&self) -> ArrowResult<Vec<ColumnIndex>> {
- let (primary_is_left, primary_schema, secondary_schema) = match self.join_type {
- JoinType::Inner
- | JoinType::Left
- | JoinType::Full
- | JoinType::Semi
- | JoinType::Anti => (true, self.left.schema(), self.right.schema()),
- JoinType::Right => (false, self.right.schema(), self.left.schema()),
- };
- let mut column_indices = Vec::with_capacity(self.schema.fields().len());
- for field in self.schema.fields() {
- let (is_primary, index) = match primary_schema.index_of(field.name()) {
- Ok(i) => Ok((true, i)),
- Err(_) => {
- match secondary_schema.index_of(field.name()) {
- Ok(i) => Ok((false, i)),
- _ => Err(DataFusionError::Internal(
- format!("During execution, the column {} was not found in neither the left or right side of the join", field.name()).to_string()
- ))
- }
- }
- }.map_err(DataFusionError::into_arrow_external_error)?;
-
- let is_left =
- is_primary && primary_is_left || !is_primary && !primary_is_left;
- column_indices.push(ColumnIndex { index, is_left });
- }
-
- Ok(column_indices)
- }
}
#[async_trait]
@@ -421,7 +385,6 @@ impl ExecutionPlan for HashJoinExec {
let right_stream = self.right.execute(partition).await?;
let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
- let column_indices = self.column_indices_from_schema()?;
let num_rows = left_data.1.num_rows();
let visited_left_side = match self.join_type {
JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => {
@@ -436,7 +399,7 @@ impl ExecutionPlan for HashJoinExec {
self.join_type,
left_data,
right_stream,
- column_indices,
+ self.column_indices.clone(),
self.random_state.clone(),
visited_left_side,
HashJoinMetrics::new(partition, &self.metrics),
@@ -522,8 +485,6 @@ struct HashJoinStream {
left_data: JoinLeftData,
/// right
right: SendableRecordBatchStream,
- /// Information of index and left / right placement of columns
- column_indices: Vec<ColumnIndex>,
/// Random state used for hashing initialization
random_state: RandomState,
/// Keeps track of the left side rows whether they are visited
@@ -532,6 +493,8 @@ struct HashJoinStream {
is_exhausted: bool,
/// Metrics
join_metrics: HashJoinMetrics,
+ /// Information of index and left / right placement of columns
+ column_indices: Vec<ColumnIndex>,
}
#[allow(clippy::too_many_arguments)]
@@ -589,12 +552,15 @@ fn build_batch_from_indices(
let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
for column_index in column_indices {
- let array = if column_index.is_left {
- let array = left.column(column_index.index);
- compute::take(array.as_ref(), &left_indices, None)?
- } else {
- let array = right.column(column_index.index);
- compute::take(array.as_ref(), &right_indices, None)?
+ let array = match column_index.side {
+ JoinSide::Left => {
+ let array = left.column(column_index.index);
+ compute::take(array.as_ref(), &left_indices, None)?
+ }
+ JoinSide::Right => {
+ let array = right.column(column_index.index);
+ compute::take(array.as_ref(), &right_indices, None)?
+ }
};
columns.push(array);
}
@@ -861,12 +827,15 @@ fn produce_from_matched(
let num_rows = indices.len();
let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
for (idx, column_index) in column_indices.iter().enumerate() {
- let array = if column_index.is_left {
- let array = left_data.1.column(column_index.index);
- compute::take(array.as_ref(), &indices, None).unwrap()
- } else {
- let datatype = schema.field(idx).data_type();
- arrow::array::new_null_array(datatype, num_rows)
+ let array = match column_index.side {
+ JoinSide::Left => {
+ let array = left_data.1.column(column_index.index);
+ compute::take(array.as_ref(), &indices, None).unwrap()
+ }
+ JoinSide::Right => {
+ let datatype = schema.field(idx).data_type();
+ arrow::array::new_null_array(datatype, num_rows)
+ }
};
columns.push(array);
@@ -1375,7 +1344,7 @@ mod tests {
"| 1 | 4 | 7 | 10 | 4 | 70 |",
"| 2 | 5 | 8 | 20 | 5 | 80 |",
"| 2 | 5 | 8 | 20 | 5 | 80 |",
- "| 3 | 7 | 9 | | 7 | |",
+ "| 3 | 7 | 9 | | | |",
"+----+----+----+----+----+----+",
];
@@ -1451,9 +1420,9 @@ mod tests {
"+----+----+----+----+----+----+",
"| a1 | b1 | c1 | a2 | b1 | c2 |",
"+----+----+----+----+----+----+",
- "| 1 | 4 | 7 | | 4 | |",
- "| 2 | 5 | 8 | | 5 | |",
- "| 3 | 7 | 9 | | 7 | |",
+ "| 1 | 4 | 7 | | | |",
+ "| 2 | 5 | 8 | | | |",
+ "| 3 | 7 | 9 | | | |",
"+----+----+----+----+----+----+",
];
@@ -1523,7 +1492,7 @@ mod tests {
"+----+----+----+----+----+----+",
"| 1 | 4 | 7 | 10 | 4 | 70 |",
"| 2 | 5 | 8 | 20 | 5 | 80 |",
- "| 3 | 7 | 9 | | 7 | |",
+ "| 3 | 7 | 9 | | | |",
"+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
@@ -1563,7 +1532,7 @@ mod tests {
"+----+----+----+----+----+----+",
"| 1 | 4 | 7 | 10 | 4 | 70 |",
"| 2 | 5 | 8 | 20 | 5 | 80 |",
- "| 3 | 7 | 9 | | 7 | |",
+ "| 3 | 7 | 9 | | | |",
"+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
@@ -1672,7 +1641,7 @@ mod tests {
"+----+----+----+----+----+----+",
"| a1 | b1 | c1 | a2 | b1 | c2 |",
"+----+----+----+----+----+----+",
- "| | 6 | | 30 | 6 | 90 |",
+ "| | | | 30 | 6 | 90 |",
"| 1 | 4 | 7 | 10 | 4 | 70 |",
"| 2 | 5 | 8 | 20 | 5 | 80 |",
"+----+----+----+----+----+----+",
@@ -1709,7 +1678,7 @@ mod tests {
"+----+----+----+----+----+----+",
"| a1 | b1 | c1 | a2 | b1 | c2 |",
"+----+----+----+----+----+----+",
- "| | 6 | | 30 | 6 | 90 |",
+ "| | | | 30 | 6 | 90 |",
"| 1 | 4 | 7 | 10 | 4 | 70 |",
"| 2 | 5 | 8 | 20 | 5 | 80 |",
"+----+----+----+----+----+----+",
@@ -1808,4 +1777,43 @@ mod tests {
Ok(())
}
+
+ #[tokio::test]
+ async fn join_with_duplicated_column_names() -> Result<()> {
+ let left = build_table(
+ ("a", &vec![1, 2, 3]),
+ ("b", &vec![4, 5, 7]),
+ ("c", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a", &vec![10, 20, 30]),
+ ("b", &vec![1, 2, 7]),
+ ("c", &vec![70, 80, 90]),
+ );
+ let on = vec![(
+ // join on left.a = right.b: both inputs use the names a, b, c, so the output has duplicate names
+ Column::new_with_schema("a", &left.schema()).unwrap(),
+ Column::new_with_schema("b", &right.schema()).unwrap(),
+ )];
+
+ let join = join(left, right, on, &JoinType::Inner)?;
+
+ let columns = columns(&join.schema());
+ assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]);
+
+ let stream = join.execute(0).await?;
+ let batches = common::collect(stream).await?;
+
+ let expected = vec![
+ "+---+---+---+----+---+----+",
+ "| a | b | c | a | b | c |",
+ "+---+---+---+----+---+----+",
+ "| 1 | 4 | 7 | 10 | 1 | 70 |",
+ "| 2 | 5 | 8 | 20 | 2 | 80 |",
+ "+---+---+---+----+---+----+",
+ ];
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
}
diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs
index 6a622df..fbd0c97 100644
--- a/datafusion/src/physical_plan/hash_utils.rs
+++ b/datafusion/src/physical_plan/hash_utils.rs
@@ -26,92 +26,11 @@ use arrow::array::{
TimestampNanosecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::datatypes::{
- ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Int16Type, Int32Type,
- Int64Type, Int8Type, Schema, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+ ArrowDictionaryKeyType, ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type,
+ Int8Type, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
-use std::collections::HashSet;
use std::sync::Arc;
-use crate::logical_plan::JoinType;
-use crate::physical_plan::expressions::Column;
-
-/// The on clause of the join, as vector of (left, right) columns.
-pub type JoinOn = Vec<(Column, Column)>;
-/// Reference for JoinOn.
-pub type JoinOnRef<'a> = &'a [(Column, Column)];
-
-/// Checks whether the schemas "left" and "right" and columns "on" represent a valid join.
-/// They are valid whenever their columns' intersection equals the set `on`
-pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> {
- let left: HashSet<Column> = left
- .fields()
- .iter()
- .enumerate()
- .map(|(idx, f)| Column::new(f.name(), idx))
- .collect();
- let right: HashSet<Column> = right
- .fields()
- .iter()
- .enumerate()
- .map(|(idx, f)| Column::new(f.name(), idx))
- .collect();
-
- check_join_set_is_valid(&left, &right, on)
-}
-
-/// Checks whether the sets left, right and on compose a valid join.
-/// They are valid whenever their intersection equals the set `on`
-fn check_join_set_is_valid(
- left: &HashSet<Column>,
- right: &HashSet<Column>,
- on: &[(Column, Column)],
-) -> Result<()> {
- let on_left = &on.iter().map(|on| on.0.clone()).collect::<HashSet<_>>();
- let left_missing = on_left.difference(left).collect::<HashSet<_>>();
-
- let on_right = &on.iter().map(|on| on.1.clone()).collect::<HashSet<_>>();
- let right_missing = on_right.difference(right).collect::<HashSet<_>>();
-
- if !left_missing.is_empty() | !right_missing.is_empty() {
- return Err(DataFusionError::Plan(format!(
- "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {:?}\nMissing on the right: {:?}",
- left_missing,
- right_missing,
- )));
- };
-
- let remaining = right
- .difference(on_right)
- .cloned()
- .collect::<HashSet<Column>>();
-
- let collisions = left.intersection(&remaining).collect::<HashSet<_>>();
-
- if !collisions.is_empty() {
- return Err(DataFusionError::Plan(format!(
- "The left schema and the right schema have the following columns with the same name without being on the ON statement: {:?}. Consider aliasing them.",
- collisions,
- )));
- };
-
- Ok(())
-}
-
-/// Creates a schema for a join operation.
-/// The fields from the left side are first
-pub fn build_join_schema(left: &Schema, right: &Schema, join_type: &JoinType) -> Schema {
- let fields: Vec<Field> = match join_type {
- JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
- let left_fields = left.fields().iter();
- let right_fields = right.fields().iter();
- // left then right
- left_fields.chain(right_fields).cloned().collect()
- }
- JoinType::Semi | JoinType::Anti => left.fields().clone(),
- };
- Schema::new(fields)
-}
-
// Combines two hashes into one hash
#[inline]
fn combine_hashes(l: u64, r: u64) -> u64 {
@@ -602,65 +521,6 @@ mod tests {
use super::*;
- fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> {
- let left = left
- .iter()
- .map(|x| x.to_owned())
- .collect::<HashSet<Column>>();
- let right = right
- .iter()
- .map(|x| x.to_owned())
- .collect::<HashSet<Column>>();
- check_join_set_is_valid(&left, &right, on)
- }
-
- #[test]
- fn check_valid() -> Result<()> {
- let left = vec![Column::new("a", 0), Column::new("b1", 1)];
- let right = vec![Column::new("a", 0), Column::new("b2", 1)];
- let on = &[(Column::new("a", 0), Column::new("a", 0))];
-
- check(&left, &right, on)?;
- Ok(())
- }
-
- #[test]
- fn check_not_in_right() {
- let left = vec![Column::new("a", 0), Column::new("b", 1)];
- let right = vec![Column::new("b", 0)];
- let on = &[(Column::new("a", 0), Column::new("a", 0))];
-
- assert!(check(&left, &right, on).is_err());
- }
-
- #[test]
- fn check_not_in_left() {
- let left = vec![Column::new("b", 0)];
- let right = vec![Column::new("a", 0)];
- let on = &[(Column::new("a", 0), Column::new("a", 0))];
-
- assert!(check(&left, &right, on).is_err());
- }
-
- #[test]
- fn check_collision() {
- // column "a" would appear both in left and right
- let left = vec![Column::new("a", 0), Column::new("c", 1)];
- let right = vec![Column::new("a", 0), Column::new("b", 1)];
- let on = &[(Column::new("a", 0), Column::new("b", 1))];
-
- assert!(check(&left, &right, on).is_err());
- }
-
- #[test]
- fn check_in_right() {
- let left = vec![Column::new("a", 0), Column::new("c", 1)];
- let right = vec![Column::new("b", 0)];
- let on = &[(Column::new("a", 0), Column::new("b", 0))];
-
- assert!(check(&left, &right, on).is_ok());
- }
-
#[test]
fn create_hashes_for_float_arrays() -> Result<()> {
let f32_arr = Arc::new(Float32Array::from(vec![0.12, 0.5, 1f32, 444.7]));
diff --git a/datafusion/src/physical_plan/join_utils.rs b/datafusion/src/physical_plan/join_utils.rs
new file mode 100644
index 0000000..8359bbc
--- /dev/null
+++ b/datafusion/src/physical_plan/join_utils.rs
@@ -0,0 +1,212 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Join-related functionality used by both logical and physical plans
+
+use crate::error::{DataFusionError, Result};
+use crate::logical_plan::JoinType;
+use crate::physical_plan::expressions::Column;
+use arrow::datatypes::{Field, Schema};
+use std::collections::HashSet;
+
+/// The on clause of the join, as vector of (left, right) columns.
+pub type JoinOn = Vec<(Column, Column)>;
+/// Reference for JoinOn.
+pub type JoinOnRef<'a> = &'a [(Column, Column)];
+
+/// Checks whether the schemas "left" and "right" and columns "on" represent a valid join.
+/// They are valid when each `on` column is present (same name and index) in its side's schema.
+pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> {
+ let left: HashSet<Column> = left
+ .fields()
+ .iter()
+ .enumerate()
+ .map(|(idx, f)| Column::new(f.name(), idx))
+ .collect();
+ let right: HashSet<Column> = right
+ .fields()
+ .iter()
+ .enumerate()
+ .map(|(idx, f)| Column::new(f.name(), idx))
+ .collect();
+
+ check_join_set_is_valid(&left, &right, on)
+}
+
+/// Checks whether the sets left, right and on compose a valid join.
+/// Returns a `Plan` error if any left `on` column is not in `left` or any right one not in `right`.
+fn check_join_set_is_valid(
+ left: &HashSet<Column>,
+ right: &HashSet<Column>,
+ on: &[(Column, Column)],
+) -> Result<()> {
+ let on_left = &on.iter().map(|on| on.0.clone()).collect::<HashSet<_>>();
+ let left_missing = on_left.difference(left).collect::<HashSet<_>>();
+
+ let on_right = &on.iter().map(|on| on.1.clone()).collect::<HashSet<_>>();
+ let right_missing = on_right.difference(right).collect::<HashSet<_>>();
+
+ if !left_missing.is_empty() | !right_missing.is_empty() {
+ return Err(DataFusionError::Plan(format!(
+ "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {:?}\nMissing on the right: {:?}",
+ left_missing,
+ right_missing,
+ )));
+ };
+
+ Ok(())
+}
+
+/// Identifies which input of a join (left or right) a [`ColumnIndex`] refers to
+#[derive(Debug, Clone)]
+pub enum JoinSide {
+ /// The column comes from the left input of the join
+ Left,
+ /// The column comes from the right input of the join
+ Right,
+}
+
+/// Where one output column of a join comes from: an index into one input and which input it is
+#[derive(Debug, Clone)]
+pub struct ColumnIndex {
+ /// Index of the column within the schema of the side given by `side`
+ pub index: usize,
+ /// Which side of the join (left or right) the column is taken from
+ pub side: JoinSide,
+}
+
+/// Creates the output schema for a join operation, plus one [`ColumnIndex`] per output field.
+/// The fields from the left side are first; `Semi` and `Anti` joins output only the left fields.
+pub fn build_join_schema(
+ left: &Schema,
+ right: &Schema,
+ join_type: &JoinType,
+) -> (Schema, Vec<ColumnIndex>) {
+ let (fields, column_indices): (Vec<Field>, Vec<ColumnIndex>) = match join_type {
+ JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
+ let left_fields =
+ left.fields().iter().cloned().enumerate().map(|(index, f)| {
+ (
+ f,
+ ColumnIndex {
+ index,
+ side: JoinSide::Left,
+ },
+ )
+ });
+ let right_fields =
+ right
+ .fields()
+ .iter()
+ .cloned()
+ .enumerate()
+ .map(|(index, f)| {
+ (
+ f,
+ ColumnIndex {
+ index,
+ side: JoinSide::Right,
+ },
+ )
+ });
+
+ // all left fields first, then all right fields; indices stay relative to each side
+ left_fields.chain(right_fields).unzip()
+ }
+ JoinType::Semi | JoinType::Anti => left
+ .fields()
+ .iter()
+ .cloned()
+ .enumerate()
+ .map(|(index, f)| {
+ (
+ f,
+ ColumnIndex {
+ index,
+ side: JoinSide::Left,
+ },
+ )
+ })
+ .unzip(),
+ };
+
+ (Schema::new(fields), column_indices)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> {
+ let left = left
+ .iter()
+ .map(|x| x.to_owned())
+ .collect::<HashSet<Column>>();
+ let right = right
+ .iter()
+ .map(|x| x.to_owned())
+ .collect::<HashSet<Column>>();
+ check_join_set_is_valid(&left, &right, on)
+ }
+
+ #[test]
+ fn check_valid() -> Result<()> {
+ let left = vec![Column::new("a", 0), Column::new("b1", 1)];
+ let right = vec![Column::new("a", 0), Column::new("b2", 1)];
+ let on = &[(Column::new("a", 0), Column::new("a", 0))];
+
+ check(&left, &right, on)?;
+ Ok(())
+ }
+
+ #[test]
+ fn check_not_in_right() {
+ let left = vec![Column::new("a", 0), Column::new("b", 1)];
+ let right = vec![Column::new("b", 0)];
+ let on = &[(Column::new("a", 0), Column::new("a", 0))];
+
+ assert!(check(&left, &right, on).is_err());
+ }
+
+ #[test]
+ fn check_not_in_left() {
+ let left = vec![Column::new("b", 0)];
+ let right = vec![Column::new("a", 0)];
+ let on = &[(Column::new("a", 0), Column::new("a", 0))];
+
+ assert!(check(&left, &right, on).is_err());
+ }
+
+ #[test]
+ fn check_collision() {
+ // column "a" is in both left and right but right's "a" is not joined on; the check accepts this
+ let left = vec![Column::new("a", 0), Column::new("c", 1)];
+ let right = vec![Column::new("a", 0), Column::new("b", 1)];
+ let on = &[(Column::new("a", 0), Column::new("b", 1))];
+
+ assert!(check(&left, &right, on).is_ok());
+ }
+
+ #[test]
+ fn check_in_right() {
+ let left = vec![Column::new("a", 0), Column::new("c", 1)];
+ let right = vec![Column::new("b", 0)];
+ let on = &[(Column::new("a", 0), Column::new("b", 0))];
+
+ assert!(check(&left, &right, on).is_ok());
+ }
+}
diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs
index 3701e90..d12b217 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -622,6 +622,7 @@ pub mod functions;
pub mod hash_aggregate;
pub mod hash_join;
pub mod hash_utils;
+pub mod join_utils;
pub mod json;
pub mod limit;
pub mod math_expressions;
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 0ff5958..55dc936 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -42,7 +42,7 @@ use crate::physical_plan::repartition::RepartitionExec;
use crate::physical_plan::sort::SortExec;
use crate::physical_plan::udf;
use crate::physical_plan::windows::WindowAggExec;
-use crate::physical_plan::{hash_utils, Partitioning};
+use crate::physical_plan::{join_utils, Partitioning};
use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr};
use crate::scalar::ScalarValue;
use crate::sql::utils::{generate_sort_key, window_expr_common_partition_keys};
@@ -676,7 +676,7 @@ impl DefaultPhysicalPlanner {
Column::new(&r.name, right_df_schema.index_of_column(r)?),
))
})
- .collect::<Result<hash_utils::JoinOn>>()?;
+ .collect::<Result<join_utils::JoinOn>>()?;
if ctx_state.config.target_partitions > 1
&& ctx_state.config.repartition_joins
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 1cc903b..4cd0ed3 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -4716,6 +4716,69 @@ async fn test_regexp_is_match() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn join_tables_with_duplicated_column_name_not_in_on_constraint() -> Result<()> {
+ let batch = RecordBatch::try_from_iter(vec![
+ ("id", Arc::new(Int32Array::from(vec![1, 2, 3])) as _),
+ (
+ "country",
+ Arc::new(StringArray::from(vec!["Germany", "Sweden", "Japan"])) as _,
+ ),
+ ])
+ .unwrap();
+ let countries = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
+
+ let batch = RecordBatch::try_from_iter(vec![
+ (
+ "id",
+ Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7])) as _,
+ ),
+ (
+ "city",
+ Arc::new(StringArray::from(vec![
+ "Hamburg",
+ "Stockholm",
+ "Osaka",
+ "Berlin",
+ "Göteborg",
+ "Tokyo",
+ "Kyoto",
+ ])) as _,
+ ),
+ (
+ "country_id",
+ Arc::new(Int32Array::from(vec![1, 2, 3, 1, 2, 3, 3])) as _,
+ ),
+ ])
+ .unwrap();
+ let cities = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
+
+ let mut ctx = ExecutionContext::new();
+ ctx.register_table("countries", Arc::new(countries))?;
+ ctx.register_table("cities", Arc::new(cities))?;
+
+ // cities.id (t1.id) is not in the ON constraint, but the output contains both cities.id and
+ // countries.id (t2.id), i.e. two columns named "id"
+ let sql = "SELECT t1.id, t2.id, t1.city, t2.country FROM cities AS t1 JOIN countries AS t2 ON t1.country_id = t2.id ORDER BY t1.id";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ let expected = vec![
+ "+----+----+-----------+---------+",
+ "| id | id | city | country |",
+ "+----+----+-----------+---------+",
+ "| 1 | 1 | Hamburg | Germany |",
+ "| 2 | 2 | Stockholm | Sweden |",
+ "| 3 | 3 | Osaka | Japan |",
+ "| 4 | 1 | Berlin | Germany |",
+ "| 5 | 2 | Göteborg | Sweden |",
+ "| 6 | 3 | Tokyo | Japan |",
+ "| 7 | 3 | Kyoto | Japan |",
+ "+----+----+-----------+---------+",
+ ];
+
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
#[cfg(feature = "avro")]
#[tokio::test]
async fn avro_query() {