Posted to commits@spark.apache.org by do...@apache.org on 2023/02/07 05:20:58 UTC

[spark] branch branch-3.4 updated: [SPARK-42038][SQL] SPJ: Support partially clustered distribution

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 9a010ebfaf6 [SPARK-42038][SQL] SPJ: Support partially clustered distribution
9a010ebfaf6 is described below

commit 9a010ebfaf6eaeb2cff1242464251316da887ef1
Author: Chao Sun <su...@apple.com>
AuthorDate: Mon Feb 6 21:20:37 2023 -0800

    [SPARK-42038][SQL] SPJ: Support partially clustered distribution
    
    ### What changes were proposed in this pull request?
    
    Currently, with [storage-partitioned join](https://docs.google.com/document/d/1foTkDSM91VxKgkEcBMsuAvEjNybjja-uHk-r3vtXWFE/edit#heading=h.82w8qxfl2uwl), both sides of the join must be **fully clustered** over the partition values; that is, each Spark partition must have a distinct partition value. To guarantee this, Spark groups all the input partitions reported by a V2 data source by their partition values. The consequence, however, is that this can easily lead to data skew, when a [...]
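
    As a rough illustration of this grouping step, here is a minimal Scala sketch loosely modeled on `groupPartitions` in `DataSourceV2ScanExecBase` from the diff. The `InputSplit` case class and the use of plain `Int` partition values (instead of `InternalRow`) are hypothetical simplifications. `groupByPartitionValue` yields one group per distinct value (fully clustered), while `noGrouping` keeps each split as its own partition, which is what a partially clustered side relies on.

    ```scala
    object GroupingSketch {
      // Hypothetical model: an input partition (split) carries an Int partition value and a name.
      final case class InputSplit(partitionValue: Int, name: String)

      // Fully clustered: one Spark partition per distinct partition value, sorted by value
      // so that both sides of a join see the same canonical order.
      def groupByPartitionValue(splits: Seq[InputSplit]): Seq[(Int, Seq[InputSplit])] =
        splits.groupBy(_.partitionValue).toSeq.sortBy(_._1)

      // No grouping: every split becomes its own Spark partition, still sorted by value.
      // `sortBy` is stable, so splits with the same value keep their original order.
      def noGrouping(splits: Seq[InputSplit]): Seq[(Int, Seq[InputSplit])] =
        splits.map(s => (s.partitionValue, Seq(s))).sortBy(_._1)
    }
    ```

    With splits `b1` (value 1), `b0` (value 0) and `b2` (value 1), grouping gives two partitions, `0 -> [b0]` and `1 -> [b1, b2]`, whereas `noGrouping` gives three partitions with values `0, 1, 1`.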
    
    This PR introduces a new mechanism that requires only one side of the storage-partitioned join to be fully clustered, while the other side can be **partially clustered**, i.e., multiple Spark partitions may share the same partition value. At planning time, Spark compares the statistics from both sides of the join and picks the side with the smaller size as the fully clustered side, while the other side is partially clustered. It then replicates the partitions on the former [...]
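
    A minimal sketch of the two per-value layouts, mirroring the `Seq.fill` and `padTo` calls that `BatchScanExec` makes in the diff. It assumes splits are plain `String` names, and the size comparison (including breaking ties toward the left side) is an illustrative assumption rather than the exact planner logic; the sketch ignores any other conditions the planner may check.

    ```scala
    object ReplicationSketch {
      // Hypothetical model: a split is just a name; a Spark partition is the Seq of splits it reads.
      type Split = String

      // Heuristic from the PR description: the side with the smaller estimated size becomes
      // fully clustered. Tie-breaking in favor of the left side is an assumption of this sketch.
      def leftIsFullyClustered(leftSizeInBytes: BigInt, rightSizeInBytes: BigInt): Boolean =
        leftSizeInBytes <= rightSizeInBytes

      // Fully clustered side: all splits for one partition value form a single group, which is
      // replicated once per split that the other side holds for the same value.
      def replicate(splits: Seq[Split], numSplits: Int): Seq[Seq[Split]] =
        Seq.fill(numSplits)(splits)

      // Partially clustered side: every split becomes its own Spark partition. If fewer splits
      // remain than expected (e.g. after runtime filtering), pad with empty partitions so that
      // both sides end up with the same number of partitions.
      def padSplits(splits: Seq[Split], numSplits: Int): Seq[Seq[Split]] =
        splits.map(Seq(_)).padTo(numSplits, Seq.empty[Split])
    }
    ```

    For a partition value with three splits on the partially clustered side, `replicate` produces three copies of the fully clustered group, one paired with each split, so an inner join over the matched pairs still sees every combination of rows exactly once.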
    
    The concept of this optimization is similar to the existing `OptimizeSkewedJoin` rule and to techniques such as key salting, but it is applied at the DataSource V2 level instead and doesn't require AQE to be enabled. Unlike `OptimizeSkewedJoin`, however, this optimization is applied before any shuffle happens.
    
    A new config, `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled`, is added to enable or disable the feature. It is turned off by default.
    
    For instance, consider the following SQL query:
    ```sql
    SELECT * FROM a JOIN b on a.id = b.id
    ```
    
    where table `a` reports partitions `[0, 1, 2]` and table `b` reports partitions `[0, 1, 1, 1, 2, 2]`.
    
    Without this PR, Spark groups the input partitions so that both sides have 3 partitions: `[0, 1, 2]`. With this PR, Spark chooses the right-hand side as partially clustered and matches the left-hand side to it, so both sides have 6 partitions: `[0, 1, 1, 1, 2, 2]`.
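
    The partition layout in this example can be reproduced with a small Scala sketch. It assumes, as a simplification, that partition values are plain `Int`s and that the split count for each value comes from the partially clustered side, defaulting to 1 when that side has no split for it. `mergeWithSplitCounts` and `expand` are hypothetical names, although `BatchScanExec` in the diff likewise carries `(partitionValue, numSplits)` pairs and flattens them with `Seq.fill`.

    ```scala
    object PartialClusteringSketch {
      // Hypothetical simplification: partition values are plain Ints rather than InternalRows.

      // Union of the partition values from both sides, each paired with the number of splits the
      // partially clustered side holds for it (assumed to be 1 when that side has none).
      def mergeWithSplitCounts(fullyClustered: Seq[Int], partiallyClustered: Seq[Int]): Seq[(Int, Int)] = {
        val splitCounts: Map[Int, Int] =
          partiallyClustered.groupBy(identity).map { case (v, vs) => v -> vs.size }
        (fullyClustered ++ partiallyClustered).distinct.sorted.map(v => v -> splitCounts.getOrElse(v, 1))
      }

      // Flattens (value, numSplits) pairs into one entry per Spark partition, which both sides
      // then report as their output partitioning.
      def expand(common: Seq[(Int, Int)]): Seq[Int] =
        common.flatMap { case (v, n) => Seq.fill(n)(v) }

      def main(args: Array[String]): Unit = {
        val a = Seq(0, 1, 2)          // table `a`: one split per partition value
        val b = Seq(0, 1, 1, 1, 2, 2) // table `b`: several splits share values 1 and 2
        val common = mergeWithSplitCounts(a, b)
        println(common)         // List((0,1), (1,3), (2,2))
        println(expand(common)) // List(0, 1, 1, 1, 2, 2)
      }
    }
    ```

    Table `a`'s single split for value `1` is therefore replicated three times, once for each of `b`'s splits with that value.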
    
    Note that this PR currently relies on a simple heuristic: it always picks the side with the smaller data size, based on table statistics, as the fully clustered side, even though that side could also contain skewed partitions. In the future, we could do a finer-grained comparison based on partition values.
    
    ### Why are the changes needed?
    
    As mentioned in the previous section, this feature can help reduce data skew during storage-partitioned join when a few partitions are mapped to a large amount of data.
    
    ### Does this PR introduce _any_ user-facing change?
    
    A new Spark config, `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled`, is introduced to enable or disable the feature. It is disabled by default, so the default behavior is unchanged.
    
    ### How was this patch tested?
    
    Added new tests in `KeyGroupedPartitioningSuite`.
    
    Closes #39633 from sunchao/SPARK-42038.
    
    Authored-by: Chao Sun <su...@apple.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
    (cherry picked from commit cf3c02e0e0976824e0497abd8d612bb587608432)
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../apache/spark/sql/avro/AvroRowReaderSuite.scala |   2 +-
 .../org/apache/spark/sql/avro/AvroSuite.scala      |   6 +-
 .../org/apache/spark/sql/internal/SQLConf.scala    |  16 +
 .../sql/connector/catalog/InMemoryBaseTable.scala  |  61 ++-
 .../sql/connector/catalog/InMemoryTable.scala      |  17 +-
 .../connector/catalog/InMemoryTableCatalog.scala   |   5 +-
 .../execution/datasources/v2/BatchScanExec.scala   |  96 ++++-
 .../datasources/v2/DataSourceV2ScanExecBase.scala  |  31 +-
 .../execution/exchange/EnsureRequirements.scala    | 338 +++++++++++----
 .../apache/spark/sql/DataFrameWriterV2Suite.scala  |   4 +-
 .../spark/sql/FileBasedDataSourceSuite.scala       |   4 +-
 .../connector/KeyGroupedPartitioningSuite.scala    | 461 ++++++++++++++++++++-
 .../PruneFileSourcePartitionsSuite.scala           |   2 +-
 .../datasources/PrunePartitionSuiteBase.scala      |   2 +-
 .../datasources/orc/OrcV2SchemaPruningSuite.scala  |   2 +-
 .../exchange/EnsureRequirementsSuite.scala         |   2 +-
 16 files changed, 900 insertions(+), 149 deletions(-)

diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
index 529b78c3b7f..046ff4ef088 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
@@ -59,7 +59,7 @@ class AvroRowReaderSuite
 
       val df = spark.read.format("avro").load(dir.getCanonicalPath)
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: AvroScan, _, _, _, _, _) => f
+        case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
       }
       val filePath = fileScan.get.fileIndex.inputFiles(0)
       val fileSize = new File(new URI(filePath)).length
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index d4e85addf95..a913da7a172 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -2350,7 +2350,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
       })
 
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: AvroScan, _, _, _, _, _) => f
+        case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
       }
       assert(fileScan.nonEmpty)
       assert(fileScan.get.partitionFilters.nonEmpty)
@@ -2384,7 +2384,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
       assert(filterCondition.isDefined)
 
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: AvroScan, _, _, _, _, _) => f
+        case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
       }
       assert(fileScan.nonEmpty)
       assert(fileScan.get.partitionFilters.isEmpty)
@@ -2465,7 +2465,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
             .where("value = 'a'")
 
           val fileScan = df.queryExecution.executedPlan collectFirst {
-            case BatchScanExec(_, f: AvroScan, _, _, _, _, _) => f
+            case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
           }
           assert(fileScan.nonEmpty)
           if (filtersPushdown) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ecc35850bf0..abe47c3720d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1453,6 +1453,19 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED =
+    buildConf("spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled")
+      .doc("During a storage-partitioned join, whether to allow input partitions to be " +
+        "partially clustered, when both sides of the join are of KeyGroupedPartitioning. At " +
+        "planning time, Spark will pick the side with less data size based on table " +
+        "statistics, group and replicate them to match the other side. This is an optimization " +
+        "on skew join and can help to reduce data skewness when certain partitions are assigned " +
+        s"large amount of data. This config requires both ${V2_BUCKETING_ENABLED.key} and " +
+        s"${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be enabled")
+      .version("3.4.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
     .doc("The maximum number of buckets allowed.")
     .version("2.4.0")
@@ -4627,6 +4640,9 @@ class SQLConf extends Serializable with Logging {
   def v2BucketingPushPartValuesEnabled: Boolean =
     getConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED)
 
+  def v2BucketingPartiallyClusteredDistributionEnabled: Boolean =
+    getConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED)
+
   def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 1f7dd1b3092..e7c4c784b98 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -55,7 +55,8 @@ abstract class InMemoryBaseTable(
     val distribution: Distribution = Distributions.unspecified(),
     val ordering: Array[SortOrder] = Array.empty,
     val numPartitions: Option[Int] = None,
-    val isDistributionStrictlyRequired: Boolean = true)
+    val isDistributionStrictlyRequired: Boolean = true,
+    val numRowsPerSplit: Int = Int.MaxValue)
   extends Table with SupportsRead with SupportsWrite with SupportsMetadataColumns {
 
   protected object PartitionKeyColumn extends MetadataColumn {
@@ -90,12 +91,12 @@ abstract class InMemoryBaseTable(
       throw new IllegalArgumentException(s"Transform $t is not a supported transform")
   }
 
-  // The key `Seq[Any]` is the partition values.
-  val dataMap: mutable.Map[Seq[Any], BufferedRows] = mutable.Map.empty
+  // The key `Seq[Any]` is the partition values, value is a set of splits, each with a set of rows.
+  val dataMap: mutable.Map[Seq[Any], Seq[BufferedRows]] = mutable.Map.empty
 
-  def data: Array[BufferedRows] = dataMap.values.toArray
+  def data: Array[BufferedRows] = dataMap.values.flatten.toArray
 
-  def rows: Seq[InternalRow] = dataMap.values.flatMap(_.rows).toSeq
+  def rows: Seq[InternalRow] = dataMap.values.flatten.flatMap(_.rows).toSeq
 
   val partCols: Array[Array[String]] = partitioning.flatMap(_.references).map { ref =>
     schema.findNestedField(ref.fieldNames(), includeCollections = false) match {
@@ -196,18 +197,21 @@ abstract class InMemoryBaseTable(
       partitionSchema: StructType,
       from: Seq[Any],
       to: Seq[Any]): Boolean = {
-    val rows = dataMap.remove(from).getOrElse(new BufferedRows(from))
-    val newRows = new BufferedRows(to)
-    rows.rows.foreach { r =>
-      val newRow = new GenericInternalRow(r.numFields)
-      for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema(i).dataType))
-      for (i <- 0 until partitionSchema.length) {
-        val j = schema.fieldIndex(partitionSchema(i).name)
-        newRow.update(j, to(i))
+    val splits = dataMap.remove(from).getOrElse(Seq(new BufferedRows(from)))
+    val newSplits = splits.map { rows =>
+      val newRows = new BufferedRows(to)
+      rows.rows.foreach { r =>
+        val newRow = new GenericInternalRow(r.numFields)
+        for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema(i).dataType))
+        for (i <- 0 until partitionSchema.length) {
+          val j = schema.fieldIndex(partitionSchema(i).name)
+          newRow.update(j, to(i))
+        }
+        newRows.withRow(newRow)
       }
-      newRows.withRow(newRow)
+      newRows
     }
-    dataMap.put(to, newRows).foreach { _ =>
+    dataMap.put(to, newSplits).foreach { _ =>
       throw new IllegalStateException(
         s"The ${to.mkString("[", ", ", "]")} partition exists already")
     }
@@ -224,21 +228,24 @@ abstract class InMemoryBaseTable(
       val rows = if (key.length == schema.length) {
         emptyRows.withRow(InternalRow.fromSeq(key))
       } else emptyRows
-      dataMap.put(key, rows)
+      dataMap.put(key, Seq(rows))
     }
   }
 
   protected def clearPartition(key: Seq[Any]): Unit = dataMap.synchronized {
     assert(dataMap.contains(key))
-    dataMap(key).clear()
+    dataMap.update(key, Seq(new BufferedRows(key)))
   }
 
   def withDeletes(data: Array[BufferedRows]): InMemoryBaseTable = {
     data.foreach { p =>
-      dataMap ++= dataMap.map { case (key, currentRows) =>
-        val newRows = new BufferedRows(currentRows.key)
-        newRows.rows ++= currentRows.rows.filter(r => !p.deletes.contains(r.getInt(0)))
-        key -> newRows
+      dataMap ++= dataMap.map { case (key, currentSplits) =>
+        val newSplits = currentSplits.map { currentRows =>
+          val newRows = new BufferedRows(currentRows.key)
+          newRows.rows ++= currentRows.rows.filter(r => !p.deletes.contains(r.getInt(0)))
+          newRows
+        }
+        key -> newSplits
       }
     }
     this
@@ -254,8 +261,16 @@ abstract class InMemoryBaseTable(
     data.foreach(_.rows.foreach { row =>
       val key = getKey(row, writeSchema)
       dataMap += dataMap.get(key)
-        .map(key -> _.withRow(row))
-        .getOrElse(key -> new BufferedRows(key).withRow(row))
+          .map { splits =>
+            val newSplits = if (splits.last.rows.size >= numRowsPerSplit) {
+              splits :+ new BufferedRows(key)
+            } else {
+              splits
+            }
+            newSplits.last.withRow(row)
+            key -> newSplits
+          }
+          .getOrElse(key -> Seq(new BufferedRows(key).withRow(row)))
       addPartitionKey(key)
     })
     this
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index cd6821c8739..318248dae05 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -39,9 +39,10 @@ class InMemoryTable(
     distribution: Distribution = Distributions.unspecified(),
     ordering: Array[SortOrder] = Array.empty,
     numPartitions: Option[Int] = None,
-    isDistributionStrictlyRequired: Boolean = true)
+    isDistributionStrictlyRequired: Boolean = true,
+    override val numRowsPerSplit: Int = Int.MaxValue)
   extends InMemoryBaseTable(name, schema, partitioning, properties, distribution,
-    ordering, numPartitions, isDistributionStrictlyRequired) with SupportsDelete {
+    ordering, numPartitions, isDistributionStrictlyRequired, numRowsPerSplit) with SupportsDelete {
 
   override def canDeleteWhere(filters: Array[Filter]): Boolean = {
     InMemoryTable.supportsFilters(filters)
@@ -62,8 +63,16 @@ class InMemoryTable(
     data.foreach(_.rows.foreach { row =>
       val key = getKey(row, writeSchema)
       dataMap += dataMap.get(key)
-        .map(key -> _.withRow(row))
-        .getOrElse(key -> new BufferedRows(key).withRow(row))
+        .map { splits =>
+          val newSplits = if (splits.last.rows.size >= numRowsPerSplit) {
+            splits :+ new BufferedRows(key)
+          } else {
+            splits
+          }
+          newSplits.last.withRow(row)
+          key -> newSplits
+        }
+        .getOrElse(key -> Seq(new BufferedRows(key).withRow(row)))
       addPartitionKey(key)
     })
     this
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
index af8070652da..06ee588329c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
@@ -101,7 +101,8 @@ class BasicInMemoryTableCatalog extends TableCatalog {
       distribution: Distribution,
       ordering: Array[SortOrder],
       requiredNumPartitions: Option[Int],
-      distributionStrictlyRequired: Boolean = true): Table = {
+      distributionStrictlyRequired: Boolean = true,
+      numRowsPerSplit: Int = Int.MaxValue): Table = {
     if (tables.containsKey(ident)) {
       throw new TableAlreadyExistsException(ident.asMultipartIdentifier)
     }
@@ -110,7 +111,7 @@ class BasicInMemoryTableCatalog extends TableCatalog {
 
     val tableName = s"$name.${ident.quoted}"
     val table = new InMemoryTable(tableName, schema, partitions, properties, distribution,
-      ordering, requiredNumPartitions, distributionStrictlyRequired)
+      ordering, requiredNumPartitions, distributionStrictlyRequired, numRowsPerSplit)
     tables.put(ident, table)
     namespaces.putIfAbsent(ident.namespace.toList, Map())
     table
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index 93b35337b52..52f15cf7b65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Par
 import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.read._
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Physical plan node for scanning a batch of data from a data source v2.
@@ -39,7 +40,9 @@ case class BatchScanExec(
     keyGroupedPartitioning: Option[Seq[Expression]] = None,
     ordering: Option[Seq[SortOrder]] = None,
     @transient table: Table,
-    commonPartitionValues: Option[Seq[InternalRow]] = None) extends DataSourceV2ScanExecBase {
+    commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
+    applyPartialClustering: Boolean = false,
+    replicatePartitions: Boolean = false) extends DataSourceV2ScanExecBase {
 
   @transient lazy val batch = scan.toBatch
 
@@ -47,7 +50,9 @@ case class BatchScanExec(
   override def equals(other: Any): Boolean = other match {
     case other: BatchScanExec =>
       this.batch == other.batch && this.runtimeFilters == other.runtimeFilters &&
-          this.commonPartitionValues == other.commonPartitionValues
+          this.commonPartitionValues == other.commonPartitionValues &&
+          this.replicatePartitions == other.replicatePartitions &&
+          this.applyPartialClustering == other.applyPartialClustering
     case _ =>
       false
   }
@@ -103,7 +108,7 @@ case class BatchScanExec(
 
         case _ =>
           // no validation is needed as the data source did not report any specific partitioning
-        newPartitions.map(Seq(_))
+          newPartitions.map(Seq(_))
       }
 
     } else {
@@ -114,8 +119,12 @@ case class BatchScanExec(
   override def outputPartitioning: Partitioning = {
     super.outputPartitioning match {
       case k: KeyGroupedPartitioning if commonPartitionValues.isDefined =>
-        val values = commonPartitionValues.get
-        k.copy(numPartitions = values.length, partitionValues = values)
+        // We allow duplicated partition values if
+        // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
+        val newPartValues = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
+          Seq.fill(numSplits)(partValue)
+        }
+        k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
       case p => p
     }
   }
@@ -131,14 +140,77 @@ case class BatchScanExec(
 
       outputPartitioning match {
         case p: KeyGroupedPartitioning =>
-          val partitionMapping = finalPartitions.map(s => InternalRowComparableWrapper(
-            s.head.asInstanceOf[HasPartitionKey], p.expressions) -> s)
-            .toMap
-          finalPartitions = p.partitionValues.map { partValue =>
-            // Use empty partition for those partition values that are not present
-            partitionMapping.getOrElse(
-              InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
+          if (conf.v2BucketingPushPartValuesEnabled &&
+              conf.v2BucketingPartiallyClusteredDistributionEnabled) {
+            assert(filteredPartitions.forall(_.size == 1),
+              "Expect partitions to be not grouped when " +
+                  s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
+                  "is enabled")
+
+            val groupedPartitions = groupPartitions(finalPartitions.map(_.head), true).get
+
+            // This means the input partitions are not grouped by partition values. We'll need to
+            // check `groupByPartitionValues` and decide whether to group and replicate splits
+            // within a partition.
+            if (commonPartitionValues.isDefined && applyPartialClustering) {
+              // A mapping from the common partition values to how many splits the partition
+              // should contain. Note this no longer maintain the partition key ordering.
+              val commonPartValuesMap = commonPartitionValues
+                .get
+                .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
+                .toMap
+              val nestGroupedPartitions = groupedPartitions.map {
+                case (partValue, splits) =>
+                  // `commonPartValuesMap` should contain the part value since it's the super set.
+                  val numSplits = commonPartValuesMap
+                    .get(InternalRowComparableWrapper(partValue, p.expressions))
+                  assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
+                      "common partition values from Spark plan")
+
+                  val newSplits = if (replicatePartitions) {
+                    // We need to also replicate partitions according to the other side of join
+                    Seq.fill(numSplits.get)(splits)
+                  } else {
+                    // Not grouping by partition values: this could be the side with partially
+                    // clustered distribution. Because of dynamic filtering, we'll need to check if
+                    // the final number of splits of a partition is smaller than the original
+                    // number, and fill with empty splits if so. This is necessary so that both
+                    // sides of a join will have the same number of partitions & splits.
+                    splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
+                  }
+                  (InternalRowComparableWrapper(partValue, p.expressions), newSplits)
+              }
+
+              // Now fill missing partition keys with empty partitions
+              val partitionMapping = nestGroupedPartitions.toMap
+              finalPartitions = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
+                // Use empty partition for those partition values that are not present.
+                partitionMapping.getOrElse(
+                  InternalRowComparableWrapper(partValue, p.expressions),
+                  Seq.fill(numSplits)(Seq.empty))
+              }
+            } else {
+              val partitionMapping = groupedPartitions.map { case (row, parts) =>
+                InternalRowComparableWrapper(row, p.expressions) -> parts
+              }.toMap
+              finalPartitions = p.partitionValues.map { partValue =>
+                // Use empty partition for those partition values that are not present
+                partitionMapping.getOrElse(
+                  InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
+              }
+            }
+          } else {
+            val partitionMapping = finalPartitions.map { parts =>
+              val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey()
+              InternalRowComparableWrapper(row, p.expressions) -> parts
+            }.toMap
+            finalPartitions = p.partitionValues.map { partValue =>
+              // Use empty partition for those partition values that are not present
+              partitionMapping.getOrElse(
+                InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
+            }
           }
+
         case _ =>
       }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
index 8a7c4729a0e..e539b1c4ee3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
@@ -121,7 +121,11 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
    * for further optimizations to eliminate shuffling in some operations such as join and aggregate.
    */
   def groupPartitions(
-      inputPartitions: Seq[InputPartition]): Option[Seq[(InternalRow, Seq[InputPartition])]] = {
+      inputPartitions: Seq[InputPartition],
+      groupSplits: Boolean = !conf.v2BucketingPushPartValuesEnabled ||
+          !conf.v2BucketingPartiallyClusteredDistributionEnabled):
+    Option[Seq[(InternalRow, Seq[InputPartition])]] = {
+
     if (!SQLConf.get.v2BucketingEnabled) return None
     keyGroupedPartitioning.flatMap { expressions =>
       val results = inputPartitions.takeWhile {
@@ -133,21 +137,28 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
         // Not all of the `InputPartitions` implements `HasPartitionKey`, therefore skip here.
         None
       } else {
-        val groupedPartitions = inputPartitions.map(p =>
-            (InternalRowComparableWrapper(p.asInstanceOf[HasPartitionKey], expressions), p))
-          .groupBy(_._1)
-          .toSeq
-          .map {
-            case (key, s) => (key.row, s.map(_._2))
-          }
-
         // also sort the input partitions according to their partition key order. This ensures
         // a canonical order from both sides of a bucketed join, for example.
         val partitionDataTypes = expressions.map(_.dataType)
         val partitionOrdering: Ordering[(InternalRow, Seq[InputPartition])] = {
           RowOrdering.createNaturalAscendingOrdering(partitionDataTypes).on(_._1)
         }
-        Some(groupedPartitions.sorted(partitionOrdering))
+
+        val partitions = if (groupSplits) {
+          // Group the splits by their partition value
+          results
+            .map(t => (InternalRowComparableWrapper(t._1, expressions), t._2))
+            .groupBy(_._1)
+            .toSeq
+            .map {
+              case (key, s) => (key.row, s.map(_._2))
+            }
+        } else {
+          // No splits grouping, each split will become a separate Spark partition
+          results.map(t => (t._1, Seq(t._2)))
+        }
+
+        Some(partitions.sorted(partitionOrdering))
       }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 4b229fcbfcd..bc90a869fd9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
@@ -51,6 +52,7 @@ case class EnsureRequirements(
   extends Rule[SparkPlan] {
 
   private def ensureDistributionAndOrdering(
+      parent: Option[SparkPlan],
       originalChildren: Seq[SparkPlan],
       requiredChildDistributions: Seq[Distribution],
       requiredChildOrderings: Seq[Seq[SortOrder]],
@@ -149,79 +151,25 @@ case class EnsureRequirements(
         Some(finalCandidateSpecs.values.maxBy(_.numPartitions))
       }
 
-      // Retrieve the non-collection spec from the input
-      def getRootSpec(spec: ShuffleSpec): ShuffleSpec = spec match {
-        case ShuffleSpecCollection(specs) => getRootSpec(specs.head)
-        case spec => spec
-      }
-
-      // Populate the common partition values down to the scan nodes
-      def populatePartitionValues(plan: SparkPlan, values: Seq[InternalRow]): SparkPlan =
-        plan match {
-          case scan: BatchScanExec =>
-            scan.copy(commonPartitionValues = Some(values))
-          case node =>
-            node.mapChildren(child => populatePartitionValues(child, values))
-        }
-
       // Check if the following conditions are satisfied:
       //   1. There are exactly two children (e.g., join). Note that Spark doesn't support
       //      multi-way join at the moment, so this check should be sufficient.
       //   2. All children are of `KeyGroupedPartitioning`, and they are compatible with each other
       // If both are true, skip shuffle.
-      val allCompatible = childrenIndexes.length == 2 && {
-        val left = childrenIndexes.head
-        val right = childrenIndexes(1)
-        var isCompatible: Boolean = false
-
-        if (checkKeyGroupedSpec(specs(left)) && checkKeyGroupedSpec(specs(right))) {
-          isCompatible = specs(left).isCompatibleWith(specs(right))
-
-          // If `isCompatible` is false, it could mean:
-          //   1. Partition keys (expressions) are not compatible: we have to shuffle in this case.
-          //   2. Partition keys (expressions) are compatible, but partition values are not: in this
-          //      case we can compute a superset of partition values and push-down to respective
-          //      data sources, which can then adjust their respective output partitioning by
-          //      filling missing partition values with empty partitions. As result, Spark can still
-          //      avoid shuffle.
-          //
-          // For instance, if two sides of a join have partition expressions `day(a)` and `day(b)`
-          // respectively (the join query could be `SELECT ... FROM t1 JOIN t2 on t1.a = t2.b`),
-          // but with different partition values:
-          //   `day(a)`: [0, 1]
-          //   `day(b)`: [1, 2, 3]
-          // Following the case 2 above, we don't have to shuffle both sides, but instead can just
-          // push the common set of partition values: `[0, 1, 2, 3]` down to the two data sources.
-          if (!isCompatible && conf.v2BucketingPushPartValuesEnabled) {
-            (getRootSpec(specs(left)), getRootSpec(specs(right))) match {
-              case (leftSpec: KeyGroupedShuffleSpec, rightSpec: KeyGroupedShuffleSpec) =>
-                // Check if the two children are partition keys compatible. If so, find the
-                // common set of partition values, and adjust the plan accordingly.
-                if (leftSpec.areKeysCompatible(rightSpec)) {
-                  val mergedPartValues = InternalRowComparableWrapper.mergePartitions(
-                    leftSpec.partitioning, rightSpec.partitioning,
-                    leftSpec.partitioning.expressions)
-                  // Now we need to push-down the common partition key to the scan in each child
-                  children = children.zipWithIndex.map {
-                    case (child, idx) if childrenIndexes.contains(idx) =>
-                      populatePartitionValues(child, mergedPartValues)
-                    case (child, _) => child
-                  }
-
-                  isCompatible = true
-                }
-              case _ =>
-                // This case shouldn't happen since `checkKeyGroupedSpec` should've made
-                // sure that we only have `KeyGroupedShuffleSpec`
-            }
-          }
+      val isKeyGroupCompatible = parent.isDefined &&
+          children.length == 2 && childrenIndexes.length == 2 && {
+        val left = children.head
+        val right = children(1)
+        val newChildren = checkKeyGroupCompatible(
+          parent.get, left, right, requiredChildDistributions)
+        if (newChildren.isDefined) {
+          children = newChildren.get
         }
-
-        isCompatible
+        newChildren.isDefined
       }
 
       children = children.zip(requiredChildDistributions).zipWithIndex.map {
-        case ((child, _), idx) if allCompatible || !childrenIndexes.contains(idx) =>
+        case ((child, _), idx) if isKeyGroupCompatible || !childrenIndexes.contains(idx) =>
           child
         case ((child, dist), idx) =>
           if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) {
@@ -260,26 +208,6 @@ case class EnsureRequirements(
     children
   }
 
-  private def checkKeyGroupedSpec(shuffleSpec: ShuffleSpec): Boolean = {
-    def check(spec: KeyGroupedShuffleSpec): Boolean = {
-      val attributes = spec.partitioning.expressions.flatMap(_.collectLeaves())
-      val clustering = spec.distribution.clustering
-
-      if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
-        attributes.length == clustering.length && attributes.zip(clustering).forall {
-          case (l, r) => l.semanticEquals(r)
-        }
-      } else {
-        true // already validated in `KeyGroupedPartitioning.satisfies`
-      }
-    }
-    shuffleSpec match {
-      case spec: KeyGroupedShuffleSpec => check(spec)
-      case ShuffleSpecCollection(specs) => specs.exists(checkKeyGroupedSpec)
-      case _ => false
-    }
-  }
-
   private def reorder(
       leftKeys: IndexedSeq[Expression],
       rightKeys: IndexedSeq[Expression],
@@ -408,6 +336,246 @@ case class EnsureRequirements(
     }
   }
 
+  /**
+   * Checks whether two children, `left` and `right`, of a join operator have compatible
+   * `KeyGroupedPartitioning`, and can benefit from storage-partitioned join.
+   *
+   * Returns the updated new children if the check is successful, otherwise `None`.
+   */
+  private def checkKeyGroupCompatible(
+      parent: SparkPlan,
+      left: SparkPlan,
+      right: SparkPlan,
+      requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = {
+    parent match {
+      case smj: SortMergeJoinExec =>
+        checkKeyGroupCompatible(left, right, smj.joinType, requiredChildDistribution)
+      case sj: ShuffledHashJoinExec =>
+        checkKeyGroupCompatible(left, right, sj.joinType, requiredChildDistribution)
+      case _ =>
+        None
+    }
+  }
+
+  private def checkKeyGroupCompatible(
+      left: SparkPlan,
+      right: SparkPlan,
+      joinType: JoinType,
+      requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = {
+    assert(requiredChildDistribution.length == 2)
+
+    var newLeft = left
+    var newRight = right
+
+    val specs = Seq(left, right).zip(requiredChildDistribution).map { case (p, d) =>
+      if (!d.isInstanceOf[ClusteredDistribution]) return None
+      val cd = d.asInstanceOf[ClusteredDistribution]
+      val specOpt = createKeyGroupedShuffleSpec(p.outputPartitioning, cd)
+      if (specOpt.isEmpty) return None
+      specOpt.get
+    }
+
+    val leftSpec = specs.head
+    val rightSpec = specs(1)
+
+    var isCompatible = false
+    if (!conf.v2BucketingPushPartValuesEnabled) {
+      isCompatible = leftSpec.isCompatibleWith(rightSpec)
+    } else {
+      logInfo("Pushing common partition values for storage-partitioned join")
+      isCompatible = leftSpec.areKeysCompatible(rightSpec)
+
+      // If partition expressions are compatible, then regardless of whether the partition
+      // values from both children match, we can calculate a superset of partition values and
+      // push it down to the respective data sources, so they can adjust their output
+      // partitioning by filling missing partition keys with empty partitions. As a result, we
+      // can still avoid shuffle.
+      //
+      // For instance, if two sides of a join have partition expressions
+      // `day(a)` and `day(b)` respectively
+      // (the join query could be `SELECT ... FROM t1 JOIN t2 on t1.a = t2.b`), but
+      // with different partition values:
+      //   `day(a)`: [0, 1]
+      //   `day(b)`: [1, 2, 3]
+      // In this case, we don't have to shuffle either side, but can instead just push the
+      // common set of partition values, `[0, 1, 2, 3]`, down to the two data
+      // sources.
+      if (isCompatible) {
+        val leftPartValues = leftSpec.partitioning.partitionValues
+        val rightPartValues = rightSpec.partitioning.partitionValues
+
+        logInfo(
+          s"""
+             |Left side # of partitions: ${leftPartValues.size}
+             |Right side # of partitions: ${rightPartValues.size}
+             |""".stripMargin)
+
+        // As partition keys are compatible, we can pick either left or right as partition
+        // expressions
+        val partitionExprs = leftSpec.partitioning.expressions
+
+        var mergedPartValues = InternalRowComparableWrapper
+            .mergePartitions(leftSpec.partitioning, rightSpec.partitioning, partitionExprs)
+            .map(v => (v, 1))
+
+        logInfo(s"After merging, there are ${mergedPartValues.size} partitions")
+
+        var replicateLeftSide = false
+        var replicateRightSide = false
+        var applyPartialClustering = false
+
+        // If partially clustered distribution is enabled, we allow partitions that are not
+        // clustered on their values, that is, multiple partitions with the same partition
+        // value. In the following, we calculate how many partitions each distinct partition
+        // value has, and push down this information to the scans, so they can adjust their
+        // final input partitions accordingly.
+        if (conf.v2BucketingPartiallyClusteredDistributionEnabled) {
+          logInfo("Calculating partially clustered distribution for " +
+              "storage-partitioned join")
+
+          // Similar to `OptimizeSkewedJoin`, we need to check the join type and decide
+          // whether partially clustered distribution can be applied. For instance, the
+          // optimization cannot be applied to a left outer join where the left-hand
+          // side is chosen, according to stats, as the side to replicate partitions.
+          // Otherwise, the query result could be incorrect.
+          val canReplicateLeft = canReplicateLeftSide(joinType)
+          val canReplicateRight = canReplicateRightSide(joinType)
+
+          if (!canReplicateLeft && !canReplicateRight) {
+            logInfo("Skipping partially clustered distribution as it cannot be applied for " +
+                s"join type '$joinType'")
+          } else {
+            val leftLink = left.logicalLink
+            val rightLink = right.logicalLink
+
+            replicateLeftSide = if (
+              leftLink.isDefined && rightLink.isDefined &&
+                  leftLink.get.stats.sizeInBytes > 1 &&
+                  rightLink.get.stats.sizeInBytes > 1) {
+              logInfo(
+                s"""
+                   |Using plan statistics to determine which side of join to fully
+                   |cluster partition values:
+                   |Left side size (in bytes): ${leftLink.get.stats.sizeInBytes}
+                   |Right side size (in bytes): ${rightLink.get.stats.sizeInBytes}
+                   |""".stripMargin)
+              leftLink.get.stats.sizeInBytes < rightLink.get.stats.sizeInBytes
+            } else {
+              // As a simple heuristic, we pick the side with fewer partitions as the side
+              // whose partitions are grouped and replicated
+              logInfo("Using number of partitions to determine which side of join " +
+                  "to fully cluster partition values")
+              leftPartValues.size < rightPartValues.size
+            }
+
+            replicateRightSide = !replicateLeftSide
+
+            // Similar to skewed join, we need to check the join type to see whether replication
+            // of partitions can be applied. For instance, replication should not be allowed for
+            // the left-hand side of a right outer join.
+            if (replicateLeftSide && !canReplicateLeft) {
+              logInfo("Left-hand side is picked but cannot be applied to join type " +
+                  s"'$joinType'. Skipping partially clustered distribution.")
+              replicateLeftSide = false
+            } else if (replicateRightSide && !canReplicateRight) {
+              logInfo("Right-hand side is picked but cannot be applied to join type " +
+                  s"'$joinType'. Skipping partially clustered distribution.")
+              replicateRightSide = false
+            } else {
+              val partValues = if (replicateLeftSide) rightPartValues else leftPartValues
+              val numExpectedPartitions = partValues
+                .map(InternalRowComparableWrapper(_, partitionExprs))
+                .groupBy(identity)
+                .mapValues(_.size)
+
+              mergedPartValues = mergedPartValues.map { case (partVal, numParts) =>
+                (partVal, numExpectedPartitions.getOrElse(
+                  InternalRowComparableWrapper(partVal, partitionExprs), numParts))
+              }
+
+              logInfo("After applying partially clustered distribution, there are " +
+                  s"${mergedPartValues.map(_._2).sum} partitions.")
+              applyPartialClustering = true
+            }
+          }
+        }
+
+        // Now push down the common partition values to the scan in each child
+        newLeft = populatePartitionValues(
+          left, mergedPartValues, applyPartialClustering, replicateLeftSide)
+        newRight = populatePartitionValues(
+          right, mergedPartValues, applyPartialClustering, replicateRightSide)
+      }
+    }
+
+    if (isCompatible) Some(Seq(newLeft, newRight)) else None
+  }
+
+  // Similar to `OptimizeSkewedJoin.canSplitRightSide`
+  private def canReplicateLeftSide(joinType: JoinType): Boolean = {
+    joinType == Inner || joinType == Cross || joinType == RightOuter
+  }
+
+  // Similar to `OptimizeSkewedJoin.canSplitLeftSide`
+  private def canReplicateRightSide(joinType: JoinType): Boolean = {
+    joinType == Inner || joinType == Cross || joinType == LeftSemi ||
+        joinType == LeftAnti || joinType == LeftOuter
+  }
+
+  // Push the common partition values and partial clustering flags down to the scan nodes
+  private def populatePartitionValues(
+      plan: SparkPlan,
+      values: Seq[(InternalRow, Int)],
+      applyPartialClustering: Boolean,
+      replicatePartitions: Boolean): SparkPlan = plan match {
+    case scan: BatchScanExec =>
+      scan.copy(
+        commonPartitionValues = Some(values),
+        applyPartialClustering = applyPartialClustering,
+        replicatePartitions = replicatePartitions
+      )
+    case node =>
+      node.mapChildren(child => populatePartitionValues(
+        child, values, applyPartialClustering, replicatePartitions))
+  }
+
+  /**
+   * Tries to create a [[KeyGroupedShuffleSpec]] from the input partitioning and distribution, if
+   * the partitioning is a [[KeyGroupedPartitioning]] (either directly or indirectly), and
+   * satisfies the given distribution.
+   */
+  private def createKeyGroupedShuffleSpec(
+      partitioning: Partitioning,
+      distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = {
+    def check(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = {
+      val attributes = partitioning.expressions.flatMap(_.collectLeaves())
+      val clustering = distribution.clustering
+
+      val satisfies = if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
+        attributes.length == clustering.length && attributes.zip(clustering).forall {
+          case (l, r) => l.semanticEquals(r)
+        }
+      } else {
+        partitioning.satisfies(distribution)
+      }
+
+      if (satisfies) {
+        Some(partitioning.createShuffleSpec(distribution).asInstanceOf[KeyGroupedShuffleSpec])
+      } else {
+        None
+      }
+    }
+
+    partitioning match {
+      case p: KeyGroupedPartitioning => check(p)
+      case PartitioningCollection(partitionings) =>
+        val specs = partitionings.map(p => createKeyGroupedShuffleSpec(p, distribution))
+        assert(specs.forall(_.isEmpty) || specs.forall(_.isDefined))
+        specs.head
+      case _ => None
+    }
+  }
+
   def apply(plan: SparkPlan): SparkPlan = {
     val newPlan = plan.transformUp {
       case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin)
@@ -430,6 +598,7 @@ case class EnsureRequirements(
       case operator: SparkPlan =>
         val reordered = reorderJoinPredicates(operator)
         val newChildren = ensureDistributionAndOrdering(
+          Some(reordered),
           reordered.children,
           reordered.requiredChildDistribution,
           reordered.requiredChildOrdering,
@@ -444,6 +613,7 @@ case class EnsureRequirements(
         REPARTITION_BY_COL
       }
       val finalPlan = ensureDistributionAndOrdering(
+        None,
         newPlan :: Nil,
         requiredDistribution.get :: Nil,
         Seq(Nil),
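To make the partition-count arithmetic in `checkKeyGroupCompatible` concrete, here is a minimal, self-contained sketch of the planning step. It is not Spark code, and it rests on several simplifications. Partition values are modeled as plain `Int`s rather than `InternalRow`s wrapped in `InternalRowComparableWrapper`. `JoinType` is a hypothetical stand-in enum. Only the partition-count fallback heuristic is used, not the `sizeInBytes` comparison from plan statistics. The sketch also assumes each side reports one partition value per input split, which is what the `groupBy(identity)` counting step in the diff relies on.

```scala
object PartiallyClusteredSketch {
  // Hypothetical stand-ins for Spark's join types (illustration only)
  sealed trait JoinType
  case object Inner extends JoinType
  case object Cross extends JoinType
  case object LeftOuter extends JoinType
  case object RightOuter extends JoinType
  case object FullOuter extends JoinType
  case object LeftSemi extends JoinType
  case object LeftAnti extends JoinType

  // Same rules as `canReplicateLeftSide` / `canReplicateRightSide` in the diff
  def canReplicateLeftSide(joinType: JoinType): Boolean =
    joinType == Inner || joinType == Cross || joinType == RightOuter

  def canReplicateRightSide(joinType: JoinType): Boolean =
    joinType == Inner || joinType == Cross || joinType == LeftSemi ||
      joinType == LeftAnti || joinType == LeftOuter

  /** Merged partition values, each paired with its expected number of splits. */
  final case class Plan(
      values: Seq[(Int, Int)],
      replicateLeft: Boolean,
      replicateRight: Boolean) {
    def numPartitions: Int = values.map(_._2).sum
  }

  /** `left` / `right` hold the partition value of every input split on each side. */
  def plan(
      left: Seq[Int],
      right: Seq[Int],
      joinType: JoinType,
      partiallyClustered: Boolean): Plan = {
    // Superset of partition values from both sides, one partition per value by default
    val merged = (left ++ right).distinct.sorted.map(v => (v, 1))
    val fullyClustered = Plan(merged, replicateLeft = false, replicateRight = false)

    val canLeft = canReplicateLeftSide(joinType)
    val canRight = canReplicateRightSide(joinType)
    if (!partiallyClustered || (!canLeft && !canRight)) {
      fullyClustered
    } else {
      // Fallback heuristic only: replicate the side with fewer input partitions
      val replicateLeft = left.size < right.size
      if ((replicateLeft && !canLeft) || (!replicateLeft && !canRight)) {
        fullyClustered
      } else {
        // Count the splits per value on the partially clustered (non-replicated) side
        val partial = if (replicateLeft) right else left
        val counts = partial.groupBy(identity).map { case (v, splits) => v -> splits.size }
        Plan(
          merged.map { case (v, n) => (v, counts.getOrElse(v, n)) },
          replicateLeft = replicateLeft,
          replicateRight = !replicateLeft)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Split layout of the "missing keys on left-hand side" test (one row per split)
    val items = Seq(1, 1, 3, 4)
    val purchases = Seq(1, 1, 2, 2, 2, 2, 3, 5)
    println(plan(items, purchases, Inner, partiallyClustered = true).numPartitions)  // 9
    println(plan(items, purchases, Inner, partiallyClustered = false).numPartitions) // 5
  }
}
```

With one row per split, this sketch gives the `expected` partition counts asserted by the new `KeyGroupedPartitioningSuite` tests when the feature is enabled: 5, 7, 10, 9, 6 and 7. In the real suite, the stats-based comparison may be what actually selects the replicated side.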
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
index 913a77cedb7..507207a2fdd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
@@ -716,8 +716,8 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
     assert(table.partitioning === Seq(IdentityTransform(FieldReference(Array("ts", "timezone")))))
     checkAnswer(spark.table(table.name), data)
     assert(table.dataMap.toArray.length == 2)
-    assert(table.dataMap(Seq(UTF8String.fromString("America/Los_Angeles"))).rows.size == 2)
-    assert(table.dataMap(Seq(UTF8String.fromString("America/New_York"))).rows.size == 1)
+    assert(table.dataMap(Seq(UTF8String.fromString("America/Los_Angeles"))).head.rows.size == 2)
+    assert(table.dataMap(Seq(UTF8String.fromString("America/New_York"))).head.rows.size == 1)
 
     // TODO: `DataSourceV2Strategy` can not translate nested fields into source filter yet
     // so the following sql will fail.
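The test-side changes suggest how partially clustered input is produced in tests. `dataMap(...)` now yields a sequence whose `head` carries `rows`, and the `KeyGroupedPartitioningSuite` helper in this commit passes `numRowsPerSplit = 1` when creating tables. One plausible reading is that the in-memory test table now stores each partition value's rows as one or more splits. The model below is hypothetical and is not the actual `InMemoryTable` code; `Split` and `toSplits` are illustrative names. It shows why one row per split makes a source report several input partitions with the same partition value.

```scala
object InMemorySplitsSketch {
  // Hypothetical split: an identity partition value (e.g. `item_id`) and its rows
  final case class Split(partitionValue: Int, rows: Seq[String])

  /**
   * Groups rows by partition value, then chunks each group into splits of at most
   * `numRowsPerSplit` rows, preserving first-seen key order and row order.
   */
  def toSplits(rows: Seq[(Int, String)], numRowsPerSplit: Int): Seq[Split] = {
    require(numRowsPerSplit > 0, "numRowsPerSplit must be positive")
    val keysInOrder = rows.map(_._1).distinct
    val byKey = rows.groupBy(_._1)
    keysInOrder.flatMap { key =>
      byKey(key).map(_._2).grouped(numRowsPerSplit).map(chunk => Split(key, chunk)).toList
    }
  }

  def main(args: Array[String]): Unit = {
    // `purchases` rows from the first SPARK-42038 test: (item_id, price)
    val purchases = Seq(1 -> "45.0", 1 -> "50.0", 2 -> "15.0", 2 -> "20.0", 3 -> "20.0")
    println(toSplits(purchases, numRowsPerSplit = 1).map(_.partitionValue))  // List(1, 1, 2, 2, 3)
    println(toSplits(purchases, numRowsPerSplit = 10).map(_.partitionValue)) // List(1, 2, 3)
  }
}
```

When `numRowsPerSplit` is at least the number of rows per key, each key gets a single split. That would be consistent with the `.head.rows.size == 2` assertion above.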
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 474de0dacae..2796b1cf154 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -874,7 +874,7 @@ class FileBasedDataSourceSuite extends QueryTest
           })
 
           val fileScan = df.queryExecution.executedPlan collectFirst {
-            case BatchScanExec(_, f: FileScan, _, _, _, _, _) => f
+            case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f
           }
           assert(fileScan.nonEmpty)
           assert(fileScan.get.partitionFilters.nonEmpty)
@@ -915,7 +915,7 @@ class FileBasedDataSourceSuite extends QueryTest
           assert(filterCondition.isDefined)
 
           val fileScan = df.queryExecution.executedPlan collectFirst {
-            case BatchScanExec(_, f: FileScan, _, _, _, _, _) => f
+            case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f
           }
           assert(fileScan.nonEmpty)
           assert(fileScan.get.partitionFilters.isEmpty)
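The two extra wildcards in the `BatchScanExec` patterns above presumably account for the `applyPartialClustering` and `replicatePartitions` fields that `populatePartitionValues` sets alongside `commonPartitionValues`. This part of the diff does not show how the scan consumes these fields, so the sketch below is an assumption. It models one plausible layout:

- The replicated (fully clustered) side repeats its grouped partition once per split on the other side.
- The partially clustered side keeps its splits and pads them with empty partitions.

`Split`, `layout` and `zipJoin` are hypothetical names.

```scala
object ReplicationLayoutSketch {
  // Hypothetical input split: one partition value plus the rows it holds
  final case class Split(key: Int, rows: Seq[String])

  /**
   * Lays out one side's partitions against merged (value, numSplits) pairs:
   *  - no partial clustering: all splits of a value are grouped into one partition;
   *  - replicated side: the grouped partition is repeated numSplits times;
   *  - partially clustered side: splits are kept and padded with empty partitions.
   */
  def layout(
      splits: Seq[Split],
      merged: Seq[(Int, Int)],
      applyPartialClustering: Boolean,
      replicate: Boolean): Seq[Split] = {
    val byKey = splits.groupBy(_.key)
    merged.flatMap { case (key, numSplits) =>
      val keySplits = byKey.getOrElse(key, Seq.empty)
      if (applyPartialClustering && !replicate) {
        keySplits ++ Seq.fill(numSplits - keySplits.size)(Split(key, Seq.empty))
      } else {
        Seq.fill(numSplits)(Split(key, keySplits.flatMap(_.rows)))
      }
    }
  }

  /** Inner join performed partition by partition, pairing partitions by index. */
  def zipJoin(left: Seq[Split], right: Seq[Split]): Seq[(Int, String, String)] = {
    require(left.size == right.size, "both sides must have the same number of partitions")
    left.zip(right).flatMap { case (l, r) =>
      require(l.key == r.key, "partitions at the same index must share a partition value")
      for (a <- l.rows; b <- r.rows) yield (l.key, a, b)
    }
  }

  def main(args: Array[String]): Unit = {
    // First SPARK-42038 test: `items` is replicated, `purchases` is partially clustered
    val items = Seq(Split(1, Seq("aa")), Split(2, Seq("bb")), Split(3, Seq("cc")))
    val purchases = Seq(
      Split(1, Seq("45.0")), Split(1, Seq("50.0")),
      Split(2, Seq("15.0")), Split(2, Seq("20.0")), Split(3, Seq("20.0")))
    val merged = Seq(1 -> 2, 2 -> 2, 3 -> 1) // value -> number of splits on `purchases`
    val left = layout(items, merged, applyPartialClustering = true, replicate = true)
    val right = layout(purchases, merged, applyPartialClustering = true, replicate = false)
    zipJoin(left, right).foreach(println)
  }
}
```

In this model, pairing partitions by index yields the same inner-join rows as the `checkAnswer` in the first `SPARK-42038` test, with no shuffle. Replicating a side is only safe when that side's unmatched rows are never emitted, because duplicated rows would otherwise appear more than once. This is consistent with `canReplicateLeftSide` rejecting `LeftOuter`.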
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 6cb2313f487..09be936a0f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -240,7 +240,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       partitions: Array[Transform],
       catalog: InMemoryTableCatalog = catalog): Unit = {
     catalog.createTable(Identifier.of(Array("ns"), table),
-      schema, partitions, emptyProps, Distributions.unspecified(), Array.empty, None)
+      schema, partitions, emptyProps, Distributions.unspecified(), Array.empty, None,
+      numRowsPerSplit = 1)
   }
 
   private val customers: String = "customers"
@@ -288,6 +289,10 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       })
   }
 
+  private def collectScans(plan: SparkPlan): Seq[BatchScanExec] = {
+    collect(plan) { case s: BatchScanExec => s }
+  }
+
   test("partitioned join: exact distribution (same number of buckets) from both sides") {
     val customers_partitions = Array(bucket(4, "customer_id"))
     val orders_partitions = Array(bucket(4, "customer_id"))
@@ -536,9 +541,399 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     }
   }
 
+  test("SPARK-42038: partially clustered: with same partition keys and one side fully clustered") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, items_schema, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    val purchases_partitions = Array(identity("item_id"))
+    createTable(purchases, purchases_schema, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 45.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 50.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 15.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 20.0, cast('2020-01-03' as timestamp)), " +
+        s"(3, 20.0, cast('2020-02-01' as timestamp))")
+
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(("true", 5), ("false", 3)).foreach {
+        case (enable, expected) =>
+          withSQLConf(
+              SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+              SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) {
+            val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
+                s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+                "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")
+
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            assert(shuffles.isEmpty, "should not contain any shuffle")
+            if (pushDownValues) {
+              val scans = collectScans(df.queryExecution.executedPlan)
+              assert(scans.forall(_.inputRDD.partitions.length == expected))
+            }
+            checkAnswer(df, Seq(Row(1, "aa", 40.0, 45.0), Row(1, "aa", 40.0, 50.0),
+              Row(2, "bb", 10.0, 15.0), Row(2, "bb", 10.0, 20.0), Row(3, "cc", 15.5, 20.0)))
+          }
+      }
+    }
+  }
+
+  test("SPARK-42038: partially clustered: with same partition keys and both sides partially " +
+      "clustered") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, items_schema, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    val purchases_partitions = Array(identity("item_id"))
+    createTable(purchases, purchases_schema, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 45.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 50.0, cast('2020-01-02' as timestamp)), " +
+        s"(1, 55.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 15.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 20.0, cast('2020-01-03' as timestamp)), " +
+        s"(2, 22.0, cast('2020-01-03' as timestamp)), " +
+        s"(3, 20.0, cast('2020-02-01' as timestamp))")
+
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(("true", 7), ("false", 3)).foreach {
+        case (enable, expected) =>
+          withSQLConf(
+              SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+              SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) {
+            val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
+                s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+                "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")
+
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            assert(shuffles.isEmpty, "should not contain any shuffle")
+            if (pushDownValues) {
+              val scans = collectScans(df.queryExecution.executedPlan)
+              assert(scans.forall(_.inputRDD.partitions.length == expected))
+            }
+            checkAnswer(df, Seq(
+              Row(1, "aa", 40.0, 45.0), Row(1, "aa", 40.0, 50.0), Row(1, "aa", 40.0, 55.0),
+              Row(1, "aa", 41.0, 45.0), Row(1, "aa", 41.0, 50.0), Row(1, "aa", 41.0, 55.0),
+              Row(2, "bb", 10.0, 15.0), Row(2, "bb", 10.0, 20.0), Row(2, "bb", 10.0, 22.0),
+              Row(3, "cc", 15.5, 20.0)))
+          }
+      }
+    }
+  }
+
+  test("SPARK-42038: partially clustered: with different partition keys and both sides partially " +
+      "clustered") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, items_schema, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp)), " +
+        s"(4, 'dd', 18.0, cast('2023-01-01' as timestamp))")
+
+    val purchases_partitions = Array(identity("item_id"))
+    createTable(purchases, purchases_schema, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 45.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 50.0, cast('2020-01-02' as timestamp)), " +
+        s"(1, 55.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 15.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 20.0, cast('2020-01-03' as timestamp)), " +
+        s"(2, 25.0, cast('2020-01-03' as timestamp)), " +
+        s"(2, 30.0, cast('2020-01-03' as timestamp)), " +
+        s"(3, 20.0, cast('2020-02-01' as timestamp)), " +
+        s"(5, 30.0, cast('2023-01-01' as timestamp))")
+
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(("true", 10), ("false", 5)).foreach {
+        case (enable, expected) =>
+          withSQLConf(
+              SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+              SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) {
+            val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
+                s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+                "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")
+
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            if (pushDownValues) {
+              assert(shuffles.isEmpty, "should not contain any shuffle")
+              val scans = collectScans(df.queryExecution.executedPlan)
+              assert(scans.forall(_.inputRDD.partitions.length == expected))
+            } else {
+              assert(shuffles.nonEmpty,
+                "should contain shuffle when not pushing down partition values")
+            }
+            checkAnswer(df, Seq(
+              Row(1, "aa", 40.0, 45.0), Row(1, "aa", 40.0, 50.0), Row(1, "aa", 40.0, 55.0),
+              Row(1, "aa", 41.0, 45.0), Row(1, "aa", 41.0, 50.0), Row(1, "aa", 41.0, 55.0),
+              Row(2, "bb", 10.0, 15.0), Row(2, "bb", 10.0, 20.0), Row(2, "bb", 10.0, 25.0),
+              Row(2, "bb", 10.0, 30.0), Row(3, "cc", 15.5, 20.0)))
+          }
+      }
+    }
+  }
+
+  test("SPARK-42038: partially clustered: with different partition keys and missing keys on " +
+      "left-hand side") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, items_schema, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp)), " +
+        s"(4, 'dd', 18.0, cast('2023-01-01' as timestamp))")
+
+    val purchases_partitions = Array(identity("item_id"))
+    createTable(purchases, purchases_schema, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 45.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 50.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 15.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 20.0, cast('2020-01-03' as timestamp)), " +
+        s"(2, 25.0, cast('2020-01-03' as timestamp)), " +
+        s"(2, 30.0, cast('2020-01-03' as timestamp)), " +
+        s"(3, 20.0, cast('2020-02-01' as timestamp)), " +
+        s"(5, 30.0, cast('2023-01-01' as timestamp))")
+
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(("true", 9), ("false", 5)).foreach {
+        case (enable, expected) =>
+          withSQLConf(
+              SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+              SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) {
+            val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
+                s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+                "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")
+
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            if (pushDownValues) {
+              assert(shuffles.isEmpty, "should not contain any shuffle")
+              val scans = collectScans(df.queryExecution.executedPlan)
+              assert(scans.forall(_.inputRDD.partitions.length == expected))
+            } else {
+              assert(shuffles.nonEmpty,
+                "should contain shuffle when not pushing down partition values")
+            }
+            checkAnswer(df, Seq(
+              Row(1, "aa", 40.0, 45.0), Row(1, "aa", 40.0, 50.0),
+              Row(1, "aa", 41.0, 45.0), Row(1, "aa", 41.0, 50.0),
+              Row(3, "cc", 15.5, 20.0)))
+          }
+      }
+    }
+  }
+
+  test("SPARK-42038: partially clustered: with different partition keys and missing keys on " +
+      "right-hand side") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, items_schema, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    val purchases_partitions = Array(identity("item_id"))
+    createTable(purchases, purchases_schema, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(2, 15.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 20.0, cast('2020-01-03' as timestamp)), " +
+        s"(3, 20.0, cast('2020-02-01' as timestamp)), " +
+        s"(4, 25.0, cast('2020-02-01' as timestamp)), " +
+        s"(5, 30.0, cast('2023-01-01' as timestamp))")
+
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(("true", 6), ("false", 5)).foreach {
+        case (enable, expected) =>
+          withSQLConf(
+              SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+              SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) {
+            val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
+                s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+                "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")
+
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            if (pushDownValues) {
+              assert(shuffles.isEmpty, "should not contain any shuffle")
+              val scans = collectScans(df.queryExecution.executedPlan)
+              assert(scans.forall(_.inputRDD.partitions.length == expected))
+            } else {
+              assert(shuffles.nonEmpty,
+                "should contain shuffle when not pushing down partition values")
+            }
+            checkAnswer(df, Seq(
+              Row(2, "bb", 10.0, 15.0), Row(2, "bb", 10.0, 20.0), Row(3, "cc", 15.5, 20.0)))
+          }
+      }
+    }
+  }
+
+  test("SPARK-42038: partially clustered: left outer join") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, items_schema, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 'bb', 15.0, cast('2020-01-02' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    val purchases_partitions = Array(identity("item_id"))
+    createTable(purchases, purchases_schema, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(2, 20.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 20.0, cast('2020-02-01' as timestamp)), " +
+        s"(4, 25.0, cast('2020-02-01' as timestamp)), " +
+        s"(5, 30.0, cast('2023-01-01' as timestamp))")
+
+    // In a left outer join where the left side has larger stats, partially clustered
+    // distribution should kick in and pick the right-hand side to replicate partitions.
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(("true", 7), ("false", 5)).foreach {
+        case (enable, expected) =>
+          withSQLConf(
+            SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString,
+            SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+            SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) {
+            val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
+                s"FROM testcat.ns.$items i LEFT JOIN testcat.ns.$purchases p " +
+                "ON i.id = p.item_id AND i.arrive_time = p.time " +
+                "ORDER BY id, purchase_price, sale_price")
+
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            if (pushDownValues) {
+              assert(shuffles.isEmpty, "should not contain any shuffle")
+              val scans = collectScans(df.queryExecution.executedPlan)
+              assert(scans.forall(_.inputRDD.partitions.length == expected),
+                s"Expected $expected but got ${scans.head.inputRDD.partitions.length}")
+            } else {
+              assert(shuffles.nonEmpty,
+                "should contain shuffle when not pushing down partition values")
+            }
+            checkAnswer(df, Seq(
+              Row(1, "aa", 40.0, null), Row(1, "aa", 41.0, null),
+              Row(2, "bb", 10.0, 20.0), Row(2, "bb", 15.0, null), Row(3, "cc", 15.5, 20.0)))
+          }
+      }
+    }
+  }
+
+  test("SPARK-42038: partially clustered: right outer join") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, items_schema, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    val purchases_partitions = Array(identity("item_id"))
+    createTable(purchases, purchases_schema, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 45.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 15.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 20.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 20.0, cast('2020-02-01' as timestamp)), " +
+        s"(4, 25.0, cast('2020-02-01' as timestamp)), " +
+        s"(5, 30.0, cast('2023-01-01' as timestamp))")
+
+    // The left-hand side is picked as the side to replicate partitions based on stats, but since
+    // this is a right outer join, partially clustered distribution won't kick in, and Spark should
+    // only push down partition values on both sides.
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(("true", 5), ("false", 5)).foreach {
+        case (enable, expected) =>
+          withSQLConf(
+            SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString,
+            SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+            SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) {
+            val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
+                s"FROM testcat.ns.$items i RIGHT JOIN testcat.ns.$purchases p " +
+                "ON i.id = p.item_id AND i.arrive_time = p.time " +
+                "ORDER BY id, purchase_price, sale_price")
+
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            if (pushDownValues) {
+              assert(shuffles.isEmpty, "should not contain any shuffle")
+              val scans = collectScans(df.queryExecution.executedPlan)
+              assert(scans.map(_.inputRDD.partitions.length).toSet.size == 1)
+              assert(scans.forall(_.inputRDD.partitions.length == expected),
+                s"Expected $expected but got ${scans.head.inputRDD.partitions.length}")
+            } else {
+              assert(shuffles.nonEmpty,
+                "should contain shuffle when not pushing down partition values")
+            }
+            checkAnswer(df, Seq(
+              Row(null, null, null, 25.0), Row(null, null, null, 30.0),
+              Row(1, "aa", 40.0, 45.0),
+              Row(2, "bb", 10.0, 15.0), Row(2, "bb", 10.0, 20.0), Row(3, "cc", 15.5, 20.0)))
+          }
+      }
+    }
+  }
+
+  test("SPARK-42038: partially clustered: full outer join is not applicable") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, items_schema, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-01-01' as timestamp))")
+
+    val purchases_partitions = Array(identity("item_id"))
+    createTable(purchases, purchases_schema, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 45.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 15.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 20.0, cast('2020-01-02' as timestamp)), " +
+        s"(3, 20.0, cast('2020-01-01' as timestamp)), " +
+        s"(4, 25.0, cast('2020-01-01' as timestamp)), " +
+        s"(5, 30.0, cast('2023-01-01' as timestamp))")
+
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(("true", 5), ("false", 5)).foreach {
+        case (enable, expected) =>
+          withSQLConf(
+            SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString,
+            SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+            SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) {
+            val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
+                s"FROM testcat.ns.$items i FULL OUTER JOIN testcat.ns.$purchases p " +
+                "ON i.id = p.item_id AND i.arrive_time = p.time " +
+                "ORDER BY id, purchase_price, sale_price")
+
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            if (pushDownValues) {
+              assert(shuffles.isEmpty, "should not contain any shuffle")
+              val scans = collectScans(df.queryExecution.executedPlan)
+              assert(scans.map(_.inputRDD.partitions.length).toSet.size == 1)
+              assert(scans.forall(_.inputRDD.partitions.length == expected),
+                s"Expected $expected but got ${scans.head.inputRDD.partitions.length}")
+            } else {
+              assert(shuffles.nonEmpty,
+                "should contain shuffle when not pushing down partition values")
+            }
+            checkAnswer(df, Seq(
+              Row(null, null, null, 20.0), Row(null, null, null, 25.0), Row(null, null, null, 30.0),
+              Row(1, "aa", 40.0, 45.0), Row(1, "aa", 41.0, null),
+              Row(2, "bb", 10.0, 15.0), Row(3, "cc", 15.5, 20.0)))
+          }
+      }
+    }
+  }
+
   test("data source partitioning + dynamic partition filtering") {
     withSQLConf(
-        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
         SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
         SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
         SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false",
@@ -580,4 +975,66 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       }
     }
   }
+
+  test("SPARK-42038: partially clustered: with dynamic partition filtering") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, items_schema, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp)), " +
+        s"(4, 'dd', 18.0, cast('2023-01-01' as timestamp))")
+
+    val purchases_partitions = Array(identity("item_id"))
+    createTable(purchases, purchases_schema, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 42.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 44.0, cast('2020-01-15' as timestamp)), " +
+        s"(1, 45.0, cast('2020-01-15' as timestamp)), " +
+        s"(1, 50.0, cast('2020-01-15' as timestamp)), " +
+        s"(1, 55.0, cast('2020-01-15' as timestamp)), " +
+        s"(1, 60.0, cast('2020-01-15' as timestamp)), " +
+        s"(1, 65.0, cast('2020-01-15' as timestamp)), " +
+        s"(2, 11.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 19.5, cast('2020-02-01' as timestamp)), " +
+        s"(5, 25.0, cast('2023-01-01' as timestamp)), " +
+        s"(5, 26.0, cast('2023-01-01' as timestamp)), " +
+        s"(5, 28.0, cast('2023-01-01' as timestamp)), " +
+        s"(6, 50.0, cast('2023-02-01' as timestamp)), " +
+        s"(6, 50.0, cast('2023-02-01' as timestamp))")
+
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(("true", 15), ("false", 6)).foreach {
+        case (enable, expected) =>
+          withSQLConf(
+              SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+              SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+              SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
+              SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false",
+              SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "10",
+              SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+              SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) {
+
+            // storage-partitioned join should kick in and fill the missing partitions & splits
+            // after dynamic filtering with empty partitions & splits, respectively.
+            val df = sql(s"SELECT sum(p.price) from " +
+                s"testcat.ns.$purchases p, testcat.ns.$items i WHERE " +
+                s"p.item_id = i.id AND p.price < 45.0")
+
+            checkAnswer(df, Seq(Row(213.5)))
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            if (pushDownValues) {
+              assert(shuffles.isEmpty, "should not add shuffle for both sides of the join")
+              val scans = collectScans(df.queryExecution.executedPlan)
+              assert(scans.forall(_.inputRDD.partitions.length == expected))
+            } else {
+              assert(shuffles.nonEmpty,
+                "should contain shuffle when not pushing down partition values")
+            }
+          }
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala
index 8fd07db1ce9..7dbbb796e84 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala
@@ -165,7 +165,7 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase with Shared
   override def getScanExecPartitionSize(plan: SparkPlan): Long = {
     plan.collectFirst {
       case p: FileSourceScanExec => p.selectedPartitions.length
-      case BatchScanExec(_, scan: FileScan, _, _, _, _, _) =>
+      case BatchScanExec(_, scan: FileScan, _, _, _, _, _, _, _) =>
         scan.fileIndex.listFiles(scan.partitionFilters, scan.dataFilters).length
     }.get
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala
index ed339244372..9a61e6517f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala
@@ -95,7 +95,7 @@ abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase {
     assert(getScanExecPartitionSize(plan) == expectedPartitionCount)
 
     val collectFn: PartialFunction[SparkPlan, Seq[Expression]] = collectPartitionFiltersFn orElse {
-      case BatchScanExec(_, scan: FileScan, _, _, _, _, _) => scan.partitionFilters
+      case BatchScanExec(_, scan: FileScan, _, _, _, _, _, _, _) => scan.partitionFilters
     }
     val pushedDownPartitionFilters = plan.collectFirst(collectFn)
       .map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull]))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
index 9098c5e87a6..1fba772f5a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
@@ -42,7 +42,7 @@ class OrcV2SchemaPruningSuite extends SchemaPruningSuite with AdaptiveSparkPlanH
   override def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
     val fileSourceScanSchemata =
       collect(df.queryExecution.executedPlan) {
-        case BatchScanExec(_, scan: OrcScan, _, _, _, _, _) => scan.readDataSchema
+        case BatchScanExec(_, scan: OrcScan, _, _, _, _, _, _, _) => scan.readDataSchema
       }
     assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
       s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 844037339ab..df1ddb7d9cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -818,7 +818,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     plan2 = DummySparkPlan(
       outputPartitioning = PartitioningCollection(Seq(
         KeyGroupedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4),
-        HashPartitioning(exprA :: exprC :: Nil, 4))
+        KeyGroupedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4))
       )
     )
     smjExec = SortMergeJoinExec(


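The tests in this diff assert input-partition counts with and without partially clustered distribution (for example 5 vs 7 in the left outer join test). The Scala sketch below is a simplified model of that planning idea, not Spark's implementation. It assumes a few things: `Split`, `sideToReplicate`, `fullyClustered` and `partiallyClustered` are hypothetical names; the size arguments stand in for whatever statistics Spark actually compares; and each inserted row is modelled as its own split, which is an assumption about the in-memory test catalog rather than documented behaviour.

```scala
// Hypothetical model of storage-partitioned join planning; NOT Spark's code.
object PartiallyClusteredSketch {
  // Hypothetical stand-in for one input split reported by a V2 data source.
  final case class Split(partValue: Int, rows: Seq[String])

  sealed trait Side
  case object LeftSide extends Side
  case object RightSide extends Side

  // Replicate the side with the smaller (hypothetical) size estimate;
  // the tie-breaking rule here is an assumption.
  def sideToReplicate(leftSizeInBytes: BigInt, rightSizeInBytes: BigInt): Side =
    if (leftSizeInBytes >= rightSizeInBytes) RightSide else LeftSide

  // Fully clustered: one Spark partition per distinct partition value on either side.
  def fullyClustered(a: Seq[Split], b: Seq[Split]): Int =
    (a.map(_.partValue) ++ b.map(_.partValue)).distinct.size

  // Partially clustered: every split on the `partial` side becomes its own Spark
  // partition, paired with the whole group of `replicated` splits for that value.
  // Values present only on the replicated side still get one partition.
  def partiallyClustered(
      partial: Seq[Split],
      replicated: Seq[Split]): Seq[(Option[Split], Seq[Split])] = {
    val partialByValue = partial.groupBy(_.partValue)
    val replicatedByValue = replicated.groupBy(_.partValue)
    val values = (partialByValue.keySet ++ replicatedByValue.keySet).toSeq.sorted
    values.flatMap { v =>
      val group = replicatedByValue.getOrElse(v, Seq.empty[Split])
      partialByValue.get(v) match {
        case Some(splits) => splits.map(s => (Option(s), group)) // replicate group per split
        case None => Seq((Option.empty[Split], group))          // value only on replicated side
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Modelled on the left outer join test, one split per inserted row (assumption).
    val items = Seq(1, 1, 2, 2, 3).map(v => Split(v, Seq(s"item-$v")))
    val purchases = Seq(2, 3, 4, 5).map(v => Split(v, Seq(s"purchase-$v")))
    println(sideToReplicate(BigInt(items.size), BigInt(purchases.size))) // prints RightSide
    println(fullyClustered(items, purchases))                            // prints 5
    println(partiallyClustered(items, purchases).size)                   // prints 7
  }
}
```

Under these assumptions the model yields 5 and 7, the same numbers the left outer join test expects; treat that as a consistency check of the model, not as a statement of how the test catalog actually splits data. The design point it illustrates is that only the larger side is split further, so the replicated side's rows are re-read once per split rather than being shuffled.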