You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2022/06/06 19:51:21 UTC
[spark] branch master updated: [SPARK-39391][CORE] Reuse Partitioner classes
This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c548e593019 [SPARK-39391][CORE] Reuse Partitioner classes
c548e593019 is described below
commit c548e59301941a40ff2d07590645bcd24280a550
Author: Herman van Hovell <he...@databricks.com>
AuthorDate: Mon Jun 6 15:50:50 2022 -0400
[SPARK-39391][CORE] Reuse Partitioner classes
### What changes were proposed in this pull request?
This PR creates two new `Partitioner` classes:
- `ConstantPartitioner`: This moves all tuples in an RDD into a single partition. This replaces two anonymous partitioners in `RDD` and `ShuffleExchangeExec`.
- `PartitionIdPassthrough`: This is a dummy partitioner that passes through keys when they have already been computed as partition ids. This is actually not a new class; it was moved from `ShuffledRowRDD.scala` to core. This replaces two anonymous partitioners in `BlockMatrix` and `ShuffleExchangeExec`.
### Why are the changes needed?
Less code: four anonymous `Partitioner` implementations are replaced by two reusable classes.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing tests.
Closes #36779 from hvanhovell/SPARK-39391.
Authored-by: Herman van Hovell <he...@databricks.com>
Signed-off-by: Herman van Hovell <he...@databricks.com>
---
core/src/main/scala/org/apache/spark/Partitioner.scala | 16 ++++++++++++++++
core/src/main/scala/org/apache/spark/rdd/RDD.scala | 8 +-------
.../spark/mllib/linalg/distributed/BlockMatrix.scala | 8 +++-----
.../org/apache/spark/sql/execution/ShuffledRowRDD.scala | 8 --------
.../sql/execution/exchange/ShuffleExchangeExec.scala | 15 ++++-----------
5 files changed, 24 insertions(+), 31 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index a0cba8ab13f..5dffba2ee8e 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -129,6 +129,22 @@ class HashPartitioner(partitions: Int) extends Partitioner {
override def hashCode: Int = numPartitions
}
+/**
+ * A dummy partitioner for use with records whose partition ids have been pre-computed (i.e. for
+ * use on RDDs of (Int, Row) pairs where the Int is a partition id in the expected range).
+ */
+private[spark] class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
+ override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+}
+
+/**
+ * A [[org.apache.spark.Partitioner]] that partitions all records into a single partition.
+ */
+private[spark] class ConstantPartitioner extends Partitioner {
+ override def numPartitions: Int = 1
+ override def getPartition(key: Any): Int = 0
+}
+
/**
* A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
* equal ranges. The ranges are determined by sampling the content of the RDD passed in.
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 89397b8aa69..b7284d25122 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1249,18 +1249,12 @@ abstract class RDD[T: ClassTag](
}.foldByKey(zeroValue, new HashPartitioner(curNumPartitions))(cleanCombOp).values
}
if (finalAggregateOnExecutor && partiallyAggregated.partitions.length > 1) {
- // define a new partitioner that results in only 1 partition
- val constantPartitioner = new Partitioner {
- override def numPartitions: Int = 1
-
- override def getPartition(key: Any): Int = 0
- }
// map the partially aggregated rdd into a key-value rdd
// do the computation in the single executor with one partition
// get the new RDD[U]
partiallyAggregated = partiallyAggregated
.map(v => (0.toByte, v))
- .foldByKey(zeroValue, constantPartitioner)(cleanCombOp)
+ .foldByKey(zeroValue, new ConstantPartitioner)(cleanCombOp)
.values
}
val copiedZeroValue = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index 452bbbe5f46..2b4333fe0fd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.linalg.distributed
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Matrix => BM}
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{Partitioner, SparkException}
+import org.apache.spark.{Partitioner, PartitionIdPassthrough, SparkException}
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg._
@@ -520,10 +520,8 @@ class BlockMatrix @Since("1.3.0") (
val destinations = rightDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty)
destinations.map(j => (j, (blockRowIndex, blockColIndex, block)))
}
- val intermediatePartitioner = new Partitioner {
- override def numPartitions: Int = resultPartitioner.numPartitions * numMidDimSplits
- override def getPartition(key: Any): Int = key.asInstanceOf[Int]
- }
+ val intermediatePartitioner = new PartitionIdPassthrough(
+ resultPartitioner.numPartitions * numMidDimSplits)
val newBlocks = flatA.cogroup(flatB, intermediatePartitioner).flatMap { case (pId, (a, b)) =>
a.flatMap { case (leftRowIndex, leftColIndex, leftBlock) =>
b.filter(_._1 == leftColIndex).map { case (rightRowIndex, rightColIndex, rightBlock) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
index 47d61196fe8..7e7100338dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -70,14 +70,6 @@ case class CoalescedMapperPartitionSpec(
private final case class ShuffledRowRDDPartition(
index: Int, spec: ShufflePartitionSpec) extends Partition
-/**
- * A dummy partitioner for use with records whose partition ids have been pre-computed (i.e. for
- * use on RDDs of (Int, Row) pairs where the Int is a partition id in the expected range).
- */
-private class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
- override def getPartition(key: Any): Int = key.asInstanceOf[Int]
-}
-
/**
* A Partitioner that might group together one or more partitions from the parent.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index f3eb5636bb9..907198ad5d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -268,12 +268,9 @@ object ShuffleExchangeExec {
val part: Partitioner = newPartitioning match {
case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
case HashPartitioning(_, n) =>
- new Partitioner {
- override def numPartitions: Int = n
- // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
- // `HashPartitioning.partitionIdExpression` to produce partitioning key.
- override def getPartition(key: Any): Int = key.asInstanceOf[Int]
- }
+ // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
+ // `HashPartitioning.partitionIdExpression` to produce partitioning key.
+ new PartitionIdPassthrough(n)
case RangePartitioning(sortingExpressions, numPartitions) =>
// Extract only fields used for sorting to avoid collecting large fields that does not
// affect sorting result when deciding partition bounds in RangePartitioner
@@ -295,11 +292,7 @@ object ShuffleExchangeExec {
rddForSampling,
ascending = true,
samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
- case SinglePartition =>
- new Partitioner {
- override def numPartitions: Int = 1
- override def getPartition(key: Any): Int = 0
- }
+ case SinglePartition => new ConstantPartitioner
case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
// TODO: Handle BroadcastPartitioning.
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org