You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2022/06/06 19:51:21 UTC

[spark] branch master updated: [SPARK-39391][CORE] Reuse Partitioner classes

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c548e593019 [SPARK-39391][CORE] Reuse Partitioner classes
c548e593019 is described below

commit c548e59301941a40ff2d07590645bcd24280a550
Author: Herman van Hovell <he...@databricks.com>
AuthorDate: Mon Jun 6 15:50:50 2022 -0400

    [SPARK-39391][CORE] Reuse Partitioner classes
    
    ### What changes were proposed in this pull request?
    This PR creates two new `Partitioner` classes:
    - `ConstantPartitioner`: This moves all tuples in an RDD into a single partition. This replaces two anonymous partitioners in `RDD` and `ShuffleExchangeExec`.
    - `PartitionIdPassthrough`: This is a dummy partitioner that passes through keys when they have already been computed. This is actually not a new class; it was moved from `ShuffledRowRDD.scala` to core. This replaces two anonymous partitioners in `BlockMatrix` and `ShuffleExchangeExec`.
    
    ### Why are the changes needed?
    Less code.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing tests.
    
    Closes #36779 from hvanhovell/SPARK-39391.
    
    Authored-by: Herman van Hovell <he...@databricks.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 core/src/main/scala/org/apache/spark/Partitioner.scala   | 16 ++++++++++++++++
 core/src/main/scala/org/apache/spark/rdd/RDD.scala       |  8 +-------
 .../spark/mllib/linalg/distributed/BlockMatrix.scala     |  8 +++-----
 .../org/apache/spark/sql/execution/ShuffledRowRDD.scala  |  8 --------
 .../sql/execution/exchange/ShuffleExchangeExec.scala     | 15 ++++-----------
 5 files changed, 24 insertions(+), 31 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index a0cba8ab13f..5dffba2ee8e 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -129,6 +129,22 @@ class HashPartitioner(partitions: Int) extends Partitioner {
   override def hashCode: Int = numPartitions
 }
 
+/**
+ * A dummy partitioner for use with records whose partition ids have been pre-computed (i.e. for
+ * use on RDDs of (Int, Row) pairs where the Int is a partition id in the expected range).
+ */
+private[spark] class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
+  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+}
+
+/**
+ * A [[org.apache.spark.Partitioner]] that partitions all records into a single partition.
+ */
+private[spark] class ConstantPartitioner extends Partitioner {
+  override def numPartitions: Int = 1
+  override def getPartition(key: Any): Int = 0
+}
+
 /**
  * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
  * equal ranges. The ranges are determined by sampling the content of the RDD passed in.
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 89397b8aa69..b7284d25122 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1249,18 +1249,12 @@ abstract class RDD[T: ClassTag](
         }.foldByKey(zeroValue, new HashPartitioner(curNumPartitions))(cleanCombOp).values
       }
       if (finalAggregateOnExecutor && partiallyAggregated.partitions.length > 1) {
-        // define a new partitioner that results in only 1 partition
-        val constantPartitioner = new Partitioner {
-          override def numPartitions: Int = 1
-
-          override def getPartition(key: Any): Int = 0
-        }
         // map the partially aggregated rdd into a key-value rdd
         // do the computation in the single executor with one partition
         // get the new RDD[U]
         partiallyAggregated = partiallyAggregated
           .map(v => (0.toByte, v))
-          .foldByKey(zeroValue, constantPartitioner)(cleanCombOp)
+          .foldByKey(zeroValue, new ConstantPartitioner)(cleanCombOp)
           .values
       }
       val copiedZeroValue = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index 452bbbe5f46..2b4333fe0fd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.linalg.distributed
 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Matrix => BM}
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{Partitioner, SparkException}
+import org.apache.spark.{Partitioner, PartitionIdPassthrough, SparkException}
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.mllib.linalg._
@@ -520,10 +520,8 @@ class BlockMatrix @Since("1.3.0") (
         val destinations = rightDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty)
         destinations.map(j => (j, (blockRowIndex, blockColIndex, block)))
       }
-      val intermediatePartitioner = new Partitioner {
-        override def numPartitions: Int = resultPartitioner.numPartitions * numMidDimSplits
-        override def getPartition(key: Any): Int = key.asInstanceOf[Int]
-      }
+      val intermediatePartitioner = new PartitionIdPassthrough(
+        resultPartitioner.numPartitions * numMidDimSplits)
       val newBlocks = flatA.cogroup(flatB, intermediatePartitioner).flatMap { case (pId, (a, b)) =>
         a.flatMap { case (leftRowIndex, leftColIndex, leftBlock) =>
           b.filter(_._1 == leftColIndex).map { case (rightRowIndex, rightColIndex, rightBlock) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
index 47d61196fe8..7e7100338dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -70,14 +70,6 @@ case class CoalescedMapperPartitionSpec(
 private final case class ShuffledRowRDDPartition(
   index: Int, spec: ShufflePartitionSpec) extends Partition
 
-/**
- * A dummy partitioner for use with records whose partition ids have been pre-computed (i.e. for
- * use on RDDs of (Int, Row) pairs where the Int is a partition id in the expected range).
- */
-private class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
-  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
-}
-
 /**
  * A Partitioner that might group together one or more partitions from the parent.
  *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index f3eb5636bb9..907198ad5d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -268,12 +268,9 @@ object ShuffleExchangeExec {
     val part: Partitioner = newPartitioning match {
       case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
       case HashPartitioning(_, n) =>
-        new Partitioner {
-          override def numPartitions: Int = n
-          // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
-          // `HashPartitioning.partitionIdExpression` to produce partitioning key.
-          override def getPartition(key: Any): Int = key.asInstanceOf[Int]
-        }
+        // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
+        // `HashPartitioning.partitionIdExpression` to produce partitioning key.
+        new PartitionIdPassthrough(n)
       case RangePartitioning(sortingExpressions, numPartitions) =>
         // Extract only fields used for sorting to avoid collecting large fields that does not
         // affect sorting result when deciding partition bounds in RangePartitioner
@@ -295,11 +292,7 @@ object ShuffleExchangeExec {
           rddForSampling,
           ascending = true,
           samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
-      case SinglePartition =>
-        new Partitioner {
-          override def numPartitions: Int = 1
-          override def getPartition(key: Any): Int = 0
-        }
+      case SinglePartition => new ConstantPartitioner
       case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
       // TODO: Handle BroadcastPartitioning.
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org