Posted to commits@spark.apache.org by su...@apache.org on 2023/02/06 18:05:08 UTC

[spark] branch master updated: [SPARK-41470][SQL] SPJ: Storage-Partitioned Join should not assume InternalRow implements equals and hashCode

This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 97a20edb258 [SPARK-41470][SQL] SPJ: Storage-Partitioned Join should not assume InternalRow implements equals and hashCode
97a20edb258 is described below

commit 97a20edb25850015695ffe1a7246df55951ebd35
Author: chenliang.lu <ma...@gmail.com>
AuthorDate: Mon Feb 6 10:04:45 2023 -0800

    [SPARK-41470][SQL] SPJ: Storage-Partitioned Join should not assume InternalRow implements equals and hashCode
    
    ### What changes were proposed in this pull request?
    Introduce a new wrapper class that makes an `InternalRow` (such as the one returned by `HasPartitionKey`) comparable by pairing it with its data types, and remove `InternalRowSet`. The wrapper allows rows to be used directly in `groupBy`, `Set`, `Map` and other operations.
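    
    As an illustration (not part of this patch), a minimal sketch of how the wrapper is meant to be used: it pairs a row with its data types, delegating `hashCode` to `Murmur3HashFunction` and `equals` to `RowOrdering`, so wrapped rows can serve as keys in a `Set`, `Map` or `groupBy` even when the underlying `InternalRow` class has no value-based equality. The values and types below are made up for the example:
    
    ```scala
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
    import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
    import org.apache.spark.unsafe.types.UTF8String
    
    val dataTypes: Seq[DataType] = Seq(IntegerType, StringType)
    
    // Two distinct row objects holding the same values.
    val a = InternalRow(1, UTF8String.fromString("x"))
    val b = InternalRow(1, UTF8String.fromString("x"))
    
    // Wrapping makes equality and hashing value-based, so both rows collapse
    // to a single key regardless of the rows' own equals/hashCode.
    val keys = Set(
      new InternalRowComparableWrapper(a, dataTypes),
      new InternalRowComparableWrapper(b, dataTypes))
    assert(keys.size == 1)
    ```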
    
    ### Why are the changes needed?
    Currently SPJ (Storage-Partitioned Join) assumes that the `InternalRow` returned by `HasPartitionKey` implements equals and hashCode. We should remove this restriction.
    For example, see the [comments](https://github.com/apache/iceberg/pull/6371/files#r1056852402) suggesting that Iceberg's [StructInternalRow](https://github.com/apache/iceberg/blob/master/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/StructInternalRow.java#L362) should implement equals and hashCode; with this change that is no longer necessary.
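    
    For illustration only, a hypothetical connector-side sketch (class names are made up) of the kind of partition this relaxation enables: the partition key row keeps identity-based `equals`/`hashCode`, similar to Iceberg's `StructInternalRow`, and with this change Spark wraps such rows in `InternalRowComparableWrapper` before comparing or grouping them, so the connector does not have to add value-based `equals`/`hashCode` itself:
    
    ```scala
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
    import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition}
    
    // A key row with identity-only equals/hashCode, standing in for a row class
    // (like StructInternalRow) that does not implement value-based equality.
    class IdentityKeyRow(values: Array[Any]) extends GenericInternalRow(values) {
      override def equals(other: Any): Boolean = other match {
        case r: AnyRef => this eq r
        case _ => false
      }
      override def hashCode(): Int = System.identityHashCode(this)
    }
    
    // Hypothetical partition reporting that key row; after this change SPJ can
    // still group and compare such partitions correctly.
    case class BucketPartition(bucket: Int) extends InputPartition with HasPartitionKey {
      override def partitionKey(): InternalRow = new IdentityKeyRow(Array[Any](bucket))
    }
    ```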
    
    ### Does this PR introduce any user-facing change?
    No
    
    ### How was this patch tested?
    Existing tests.
    
    Closes #39687 from yabola/InternalRow_hashcode.
    
    Authored-by: chenliang.lu <ma...@gmail.com>
    Signed-off-by: Chao Sun <su...@apple.com>
---
 .../spark/sql/catalyst/util/InternalRowSet.scala   | 65 -----------------
 .../spark/sql/catalyst/util/InternalRowSet.scala   | 69 ------------------
 .../sql/catalyst/plans/physical/partitioning.scala |  8 +--
 .../util/InternalRowComparableWrapper.scala        | 84 ++++++++++++++++++++++
 .../sql/connector/catalog/InMemoryBaseTable.scala  | 25 ++++++-
 .../execution/datasources/v2/BatchScanExec.scala   | 41 +++++------
 .../datasources/v2/DataSourceV2ScanExecBase.scala  | 21 +++---
 .../execution/exchange/EnsureRequirements.scala    | 11 ++-
 .../spark/sql/connector/DataSourceV2Suite.scala    |  4 +-
 9 files changed, 149 insertions(+), 179 deletions(-)

diff --git a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/util/InternalRowSet.scala b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/util/InternalRowSet.scala
deleted file mode 100644
index 9e8ec042694..00000000000
--- a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/util/InternalRowSet.scala
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.util
-
-import scala.collection.mutable
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Murmur3HashFunction, RowOrdering}
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
-
-/**
- * A mutable Set with [[InternalRow]] as its element type. It uses Spark's internal murmur hash to
- * compute hash code from an row, and uses [[RowOrdering]] to perform equality checks.
- *
- * @param dataTypes the data types for the row keys this set holds
- */
-class InternalRowSet(val dataTypes: Seq[DataType]) extends mutable.Set[InternalRow] {
-  private val baseSet = new mutable.HashSet[InternalRowContainer]
-
-  private val structType = StructType(dataTypes.map(t => StructField("f", t)))
-  private val ordering = RowOrdering.createNaturalAscendingOrdering(dataTypes)
-
-  override def contains(row: InternalRow): Boolean =
-    baseSet.contains(new InternalRowContainer(row))
-
-  private class InternalRowContainer(val row: InternalRow) {
-    override def hashCode(): Int = Murmur3HashFunction.hash(row, structType, 42L).toInt
-
-    override def equals(other: Any): Boolean = other match {
-      case r: InternalRowContainer => ordering.compare(row, r.row) == 0
-      case r => this == r
-    }
-  }
-
-  override def +=(row: InternalRow): InternalRowSet.this.type = {
-    val rowKey = new InternalRowContainer(row)
-    baseSet += rowKey
-    this
-  }
-
-  override def -=(row: InternalRow): InternalRowSet.this.type = {
-    val rowKey = new InternalRowContainer(row)
-    baseSet -= rowKey
-    this
-  }
-
-  override def iterator: Iterator[InternalRow] = {
-    baseSet.iterator.map(_.row)
-  }
-}
diff --git a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/util/InternalRowSet.scala b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/util/InternalRowSet.scala
deleted file mode 100644
index 66090fdf187..00000000000
--- a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/util/InternalRowSet.scala
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.util
-
-import scala.collection.mutable
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Murmur3HashFunction, RowOrdering}
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
-
-/**
- * A mutable Set with [[InternalRow]] as its element type. It uses Spark's internal murmur hash to
- * compute hash code from an row, and uses [[RowOrdering]] to perform equality checks.
- *
- * @param dataTypes the data types for the row keys this set holds
- */
-class InternalRowSet(val dataTypes: Seq[DataType]) extends mutable.Set[InternalRow] {
-  private val baseSet = new mutable.HashSet[InternalRowContainer]
-
-  private val structType = StructType(dataTypes.map(t => StructField("f", t)))
-  private val ordering = RowOrdering.createNaturalAscendingOrdering(dataTypes)
-
-  override def contains(row: InternalRow): Boolean =
-    baseSet.contains(new InternalRowContainer(row))
-
-  private class InternalRowContainer(val row: InternalRow) {
-    override def hashCode(): Int = Murmur3HashFunction.hash(row, structType, 42L).toInt
-
-    override def equals(other: Any): Boolean = other match {
-      case r: InternalRowContainer => ordering.compare(row, r.row) == 0
-      case r => this == r
-    }
-  }
-
-  override def addOne(row: InternalRow): InternalRowSet.this.type = {
-    val rowKey = new InternalRowContainer(row)
-    baseSet += rowKey
-    this
-  }
-
-  override def subtractOne(row: InternalRow): InternalRowSet.this.type = {
-    val rowKey = new InternalRowContainer(row)
-    baseSet -= rowKey
-    this
-  }
-
-  override def clear(): Unit = {
-    baseSet.clear()
-  }
-
-  override def iterator: Iterator[InternalRow] = {
-    baseSet.iterator.map(_.row)
-  }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 73d39a19243..6512344169b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
@@ -677,9 +678,6 @@ case class KeyGroupedShuffleSpec(
     }
   }
 
-  lazy val ordering: Ordering[InternalRow] =
-    RowOrdering.createNaturalAscendingOrdering(partitioning.expressions.map(_.dataType))
-
   override def numPartitions: Int = partitioning.numPartitions
 
   override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
@@ -697,7 +695,9 @@ case class KeyGroupedShuffleSpec(
       distribution.clustering.length == otherDistribution.clustering.length &&
         numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
           partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
-            case (left, right) => ordering.compare(left, right) == 0
+            case (left, right) =>
+              InternalRowComparableWrapper(left, partitioning.expressions)
+                .equals(InternalRowComparableWrapper(right, partitioning.expressions))
           }
     case ShuffleSpecCollection(specs) =>
       specs.exists(isCompatibleWith)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
new file mode 100644
index 00000000000..b0e53090731
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, Murmur3HashFunction, RowOrdering}
+import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
+import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+
+/**
+ * Wraps an [[InternalRow]] with the corresponding [[DataType]]s so that rows can be compared
+ * by the values they contain.
+ * It uses Spark's internal murmur hash to compute a hash code for the row, and uses
+ * [[RowOrdering]] to perform equality checks.
+ *
+ * @param dataTypes the data types for the row
+ */
+class InternalRowComparableWrapper(val row: InternalRow, val dataTypes: Seq[DataType]) {
+
+  private val structType = StructType(dataTypes.map(t => StructField("f", t)))
+  private val ordering = RowOrdering.createNaturalAscendingOrdering(dataTypes)
+
+  override def hashCode(): Int = Murmur3HashFunction.hash(row, structType, 42L).toInt
+
+  override def equals(other: Any): Boolean = {
+    if (!other.isInstanceOf[InternalRowComparableWrapper]) {
+      return false
+    }
+    val otherWrapper = other.asInstanceOf[InternalRowComparableWrapper]
+    if (!otherWrapper.dataTypes.equals(this.dataTypes)) {
+      return false
+    }
+    ordering.compare(row, otherWrapper.row) == 0
+  }
+}
+
+object InternalRowComparableWrapper {
+
+  def apply(
+      partition: InputPartition with HasPartitionKey,
+      partitionExpression: Seq[Expression]): InternalRowComparableWrapper = {
+    new InternalRowComparableWrapper(
+      partition.asInstanceOf[HasPartitionKey].partitionKey(), partitionExpression.map(_.dataType))
+  }
+
+  def apply(
+      partitionRow: InternalRow,
+      partitionExpression: Seq[Expression]): InternalRowComparableWrapper = {
+    new InternalRowComparableWrapper(partitionRow, partitionExpression.map(_.dataType))
+  }
+
+  def mergePartitions(
+      leftPartitioning: KeyGroupedPartitioning,
+      rightPartitioning: KeyGroupedPartitioning,
+      partitionExpression: Seq[Expression]): Seq[InternalRow] = {
+    val partitionDataTypes = partitionExpression.map(_.dataType)
+    val partitionsSet = new mutable.HashSet[InternalRowComparableWrapper]
+    leftPartitioning.partitionValues
+      .map(new InternalRowComparableWrapper(_, partitionDataTypes))
+      .foreach(partition => partitionsSet.add(partition))
+    rightPartitioning.partitionValues
+      .map(new InternalRowComparableWrapper(_, partitionDataTypes))
+      .foreach(partition => partitionsSet.add(partition))
+    partitionsSet.map(_.row).toSeq
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 28b68a71a47..1f7dd1b3092 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -24,6 +24,7 @@ import java.util.OptionalLong
 
 import scala.collection.mutable
 
+import com.google.common.base.Objects
 import org.scalatest.Assertions._
 
 import org.apache.spark.sql.catalyst.InternalRow
@@ -541,13 +542,31 @@ class BufferedRows(val key: Seq[Any] = Seq.empty) extends WriterCommitMessage
 
   def keyString(): String = key.toArray.mkString("/")
 
-  override def partitionKey(): InternalRow = {
-    InternalRow.fromSeq(key)
-  }
+  override def partitionKey(): InternalRow = PartitionInternalRow(key.toArray)
 
   def clear(): Unit = rows.clear()
 }
 
+/**
+ * Theoretically, the [[InternalRow]] returned by [[HasPartitionKey#partitionKey()]]
+ * does not need to implement equals and hashCode.
+ * But [[GenericInternalRow]] already implements equals and hashCode, so here we override them
+ * to simulate that they are not implemented, in order to verify the code's correctness.
+ */
+case class PartitionInternalRow(keys: Array[Any])
+  extends GenericInternalRow(keys) {
+  override def equals(other: Any): Boolean = {
+    if (!other.isInstanceOf[PartitionInternalRow]) {
+      return false
+    }
+    // Just compare by reference, not by value
+    this.keys == other.asInstanceOf[PartitionInternalRow].keys
+  }
+  override def hashCode: Int = {
+    Objects.hashCode(keys)
+  }
+}
+
 private class BufferedRowsReaderFactory(
     metadataColumnNames: Seq[String],
     nonMetaDataColumns: Seq[StructField],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index d6b76ae1096..93b35337b52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -25,10 +25,9 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition}
-import org.apache.spark.sql.catalyst.util.InternalRowSet
-import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
 import org.apache.spark.sql.connector.catalog.Table
-import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeV2Filtering}
+import org.apache.spark.sql.connector.read._
 
 /**
  * Physical plan node for scanning a batch of data from a data source v2.
@@ -80,24 +79,24 @@ case class BatchScanExec(
                 "during runtime filtering: not all partitions implement HasPartitionKey after " +
                 "filtering")
           }
-
-          val newRows = new InternalRowSet(p.expressions.map(_.dataType))
-          newRows ++= newPartitions.map(_.asInstanceOf[HasPartitionKey].partitionKey())
-
-          val oldRows = p.partitionValues.toSet
-          // We require the new number of partition keys to be equal or less than the old number
-          // of partition keys here. In the case of less than, empty partitions will be added for
-          // those missing keys that are not present in the new input partitions.
-          if (oldRows.size < newRows.size) {
+          val newPartitionValues = newPartitions.map(partition =>
+              InternalRowComparableWrapper(partition.asInstanceOf[HasPartitionKey], p.expressions))
+            .toSet
+          val oldPartitionValues = p.partitionValues
+            .map(partition => InternalRowComparableWrapper(partition, p.expressions)).toSet
+          // We require the new number of partition values to be equal or less than the old number
+          // of partition values here. In the case of less than, empty partitions will be added for
+          // those missing values that are not present in the new input partitions.
+          if (oldPartitionValues.size < newPartitionValues.size) {
             throw new SparkException("During runtime filtering, data source must either report " +
-                "the same number of partition keys, or a subset of partition keys from the " +
-                s"original. Before: ${oldRows.size} partition keys. After: ${newRows.size} " +
-                "partition keys")
+                "the same number of partition values, or a subset of partition values from the " +
+                s"original. Before: ${oldPartitionValues.size} partition values. " +
+                s"After: ${newPartitionValues.size} partition values")
           }
 
-          if (!newRows.forall(oldRows.contains)) {
+          if (!newPartitionValues.forall(oldPartitionValues.contains)) {
             throw new SparkException("During runtime filtering, data source must not report new " +
-                "partition keys that are not present in the original partitioning.")
+                "partition values that are not present in the original partitioning.")
           }
 
           groupPartitions(newPartitions).get.map(_._2)
@@ -132,11 +131,13 @@ case class BatchScanExec(
 
       outputPartitioning match {
         case p: KeyGroupedPartitioning =>
-          val partitionMapping = finalPartitions.map(s =>
-            s.head.asInstanceOf[HasPartitionKey].partitionKey() -> s).toMap
+          val partitionMapping = finalPartitions.map(s => InternalRowComparableWrapper(
+            s.head.asInstanceOf[HasPartitionKey], p.expressions) -> s)
+            .toMap
           finalPartitions = p.partitionValues.map { partValue =>
             // Use empty partition for those partition values that are not present
-            partitionMapping.getOrElse(partValue, Seq.empty)
+            partitionMapping.getOrElse(
+              InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
           }
         case _ =>
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
index 556ae4afb63..8a7c4729a0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, SinglePartition}
-import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
 import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan}
 import org.apache.spark.sql.execution.{ExplainUtils, LeafExecNode, SQLExecution}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -133,18 +133,21 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
         // Not all of the `InputPartitions` implements `HasPartitionKey`, therefore skip here.
         None
       } else {
-        val partKeyType = expressions.map(_.dataType)
-
-        val groupedPartitions = results.groupBy(_._1).toSeq.map { case (key, s) =>
-          (key, s.map(_._2))
-        }
+        val groupedPartitions = inputPartitions.map(p =>
+            (InternalRowComparableWrapper(p.asInstanceOf[HasPartitionKey], expressions), p))
+          .groupBy(_._1)
+          .toSeq
+          .map {
+            case (key, s) => (key.row, s.map(_._2))
+          }
 
         // also sort the input partitions according to their partition key order. This ensures
         // a canonical order from both sides of a bucketed join, for example.
-        val keyOrdering: Ordering[(InternalRow, Seq[InputPartition])] = {
-          RowOrdering.createNaturalAscendingOrdering(partKeyType).on(_._1)
+        val partitionDataTypes = expressions.map(_.dataType)
+        val partitionOrdering: Ordering[(InternalRow, Seq[InputPartition])] = {
+          RowOrdering.createNaturalAscendingOrdering(partitionDataTypes).on(_._1)
         }
-        Some(groupedPartitions.sorted(keyOrdering))
+        Some(groupedPartitions.sorted(partitionOrdering))
       }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index f88436297e7..4b229fcbfcd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -24,11 +24,11 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.util.collection.Utils
 
 /**
  * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
@@ -198,12 +198,9 @@ case class EnsureRequirements(
                 // Check if the two children are partition keys compatible. If so, find the
                 // common set of partition values, and adjust the plan accordingly.
                 if (leftSpec.areKeysCompatible(rightSpec)) {
-                  val leftPartValues = leftSpec.partitioning.partitionValues
-                  val rightPartValues = rightSpec.partitioning.partitionValues
-
-                  val mergedPartValues = Utils.mergeOrdered(
-                    Seq(leftPartValues, rightPartValues))(leftSpec.ordering).toSeq.distinct
-
+                  val mergedPartValues = InternalRowComparableWrapper.mergePartitions(
+                    leftSpec.partitioning, rightSpec.partitioning,
+                    leftSpec.partitioning.expressions)
                   // Now we need to push-down the common partition key to the scan in each child
                   children = children.zipWithIndex.map {
                     case (child, idx) if childrenIndexes.contains(idx) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index 5c4be75e02c..02990a7a40d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -24,7 +24,7 @@ import test.org.apache.spark.sql.connector._
 
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform}
 import org.apache.spark.sql.connector.expressions.filter.Predicate
@@ -1037,7 +1037,7 @@ class OrderAndPartitionAwareDataSource extends PartitionAwareDataSource {
 case class SpecificInputPartition(
     i: Array[Int],
     j: Array[Int]) extends InputPartition with HasPartitionKey {
-  override def partitionKey(): InternalRow = InternalRow.fromSeq(Seq(i(0)))
+  override def partitionKey(): InternalRow = PartitionInternalRow(Seq(i(0)).toArray)
 }
 
 object SpecificReaderFactory extends PartitionReaderFactory {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org