You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2023/10/23 02:48:28 UTC

[spark] branch branch-3.5 updated: [SPARK-45616][CORE] Avoid ParVector, which does not propagate ThreadLocals or SparkSession

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 75a38b9024a [SPARK-45616][CORE] Avoid ParVector, which does not propagate ThreadLocals or SparkSession
75a38b9024a is described below

commit 75a38b9024af3c9cfd85e916c46359f7e7315c87
Author: Ankur Dave <an...@gmail.com>
AuthorDate: Mon Oct 23 10:47:42 2023 +0800

    [SPARK-45616][CORE] Avoid ParVector, which does not propagate ThreadLocals or SparkSession
    
    ### What changes were proposed in this pull request?
    `CastSuiteBase` and `ExpressionInfoSuite` use `ParVector.foreach()` to run Spark SQL queries in parallel. They incorrectly assume that each parallel operation will inherit the main thread’s active SparkSession. This holds only when the parallel operations run in freshly-created threads. However, if other code has already run parallel operations before Spark was started, existing threads may be reused that do not have an active SparkSession. In that case, these tests fail [...]
    
    The fix is to use the existing method `ThreadUtils.parmap()`. This method creates fresh threads that inherit the current active SparkSession, and it propagates the Spark ThreadLocals.
    
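    This PR also adds a scalastyle check (at error level) against creating a ParVector. Intentional uses must be wrapped in `// scalastyle:off parvector` / `// scalastyle:on parvector`.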
    
    ### Why are the changes needed?
    This change makes `CastSuiteBase` and `ExpressionInfoSuite` less brittle to future changes that may run parallel operations during test startup.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Reproduced the test failures by running a ParVector operation before Spark starts, and verified that this PR fixes the test failures under that condition.
    
    ```scala
      protected override def beforeAll(): Unit = {
        // Run a ParVector operation before initializing the SparkSession. This starts some Scala
        // execution context threads that have no active SparkSession. These threads will be reused for
        // later ParVector operations, reproducing SPARK-45616.
        new ParVector((0 until 100).toVector).foreach { _ => }
    
        super.beforeAll()
      }
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #43466 from ankurdave/SPARK-45616.
    
    Authored-by: Ankur Dave <an...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 376de8a502fca6b46d7f21560a60024d643144ea)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala      |  2 ++
 core/src/main/scala/org/apache/spark/util/ThreadUtils.scala  |  4 ++++
 scalastyle-config.xml                                        | 12 ++++++++++++
 .../spark/sql/catalyst/expressions/CastSuiteBase.scala       |  9 ++++++---
 .../scala/org/apache/spark/sql/execution/command/ddl.scala   |  2 ++
 .../apache/spark/sql/expressions/ExpressionInfoSuite.scala   | 11 ++++++-----
 .../main/scala/org/apache/spark/streaming/DStreamGraph.scala |  4 ++++
 .../apache/spark/streaming/util/FileBasedWriteAheadLog.scala |  2 ++
 8 files changed, 38 insertions(+), 8 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index 0a930234437..3c1451a0185 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -76,8 +76,10 @@ class UnionRDD[T: ClassTag](
 
   override def getPartitions: Array[Partition] = {
     val parRDDs = if (isPartitionListingParallel) {
+      // scalastyle:off parvector
       val parArray = new ParVector(rdds.toVector)
       parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
+      // scalastyle:on parvector
       parArray
     } else {
       rdds
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 16d7de56c39..2d3d6ec89ff 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -363,6 +363,10 @@ private[spark] object ThreadUtils {
    * Comparing to the map() method of Scala parallel collections, this method can be interrupted
    * at any time. This is useful on canceling of task execution, for example.
    *
+   * Functions are guaranteed to be executed in freshly-created threads that inherit the calling
+   * thread's Spark thread-local variables. These threads also inherit the calling thread's active
+   * SparkSession.
+   *
    * @param in - the input collection which should be transformed in parallel.
    * @param prefix - the prefix assigned to the underlying thread pool.
    * @param maxThreads - maximum number of thread can be created during execution.
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 74e8480deaf..0ccd937e72e 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -227,6 +227,18 @@ This file is divided into 3 sections:
     ]]></customMessage>
   </check>
 
+  <check customId="parvector" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
+    <parameters><parameter name="regex">new.*ParVector</parameter></parameters>
+    <customMessage><![CDATA[
+      Are you sure you want to create a ParVector? It will not automatically propagate Spark ThreadLocals or the
+      active SparkSession for the submitted tasks. In most cases, you should use ThreadUtils.parmap instead.
+      If you must use ParVector, then wrap your creation of the ParVector with
+      // scalastyle:off parvector
+      ...ParVector...
+      // scalastyle:on parvector
+    ]]></customMessage>
+  </check>
+
   <check customId="caselocale" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
     <parameters><parameter name="regex">(\.toUpperCase|\.toLowerCase)(?!(\(|\(Locale.ROOT\)))</parameter></parameters>
     <customMessage><![CDATA[
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index 0172fd9b3e4..1ce311a5544 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -22,8 +22,6 @@ import java.time.{Duration, LocalDate, LocalDateTime, Period}
 import java.time.temporal.ChronoUnit
 import java.util.{Calendar, Locale, TimeZone}
 
-import scala.collection.parallel.immutable.ParVector
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
@@ -42,6 +40,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND
 import org.apache.spark.sql.types.UpCastRule.numericPrecedence
 import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
 import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.ThreadUtils
 
 /**
  * Common test suite for [[Cast]] with ansi mode on and off. It only includes test cases that work
@@ -126,7 +125,11 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("cast string to timestamp") {
-    new ParVector(ALL_TIMEZONES.toVector).foreach { zid =>
+    ThreadUtils.parmap(
+      ALL_TIMEZONES,
+      prefix = "CastSuiteBase-cast-string-to-timestamp",
+      maxThreads = Runtime.getRuntime.availableProcessors
+    ) { zid =>
       def checkCastStringToTimestamp(str: String, expected: Timestamp): Unit = {
         checkEvaluation(cast(Literal(str), TimestampType, Option(zid.getId)), expected)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index a8f7cdb2600..bb8fea71019 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -755,8 +755,10 @@ case class RepairTableCommand(
     val statusPar: Seq[FileStatus] =
       if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) {
         // parallelize the list of partitions here, then we can have better parallelism later.
+        // scalastyle:off parvector
         val parArray = new ParVector(statuses.toVector)
         parArray.tasksupport = evalTaskSupport
+        // scalastyle:on parvector
         parArray.seq
       } else {
         statuses
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index 4dd93983e87..a02137a56aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.expressions
 
-import scala.collection.parallel.immutable.ParVector
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
@@ -26,7 +24,7 @@ import org.apache.spark.sql.execution.HiveResult.hiveResultString
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.tags.SlowSQLTest
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 @SlowSQLTest
 class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
@@ -197,8 +195,11 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
       // The encrypt expression includes a random initialization vector to its encrypted result
       classOf[AesEncrypt].getName)
 
-    val parFuncs = new ParVector(spark.sessionState.functionRegistry.listFunction().toVector)
-    parFuncs.foreach { funcId =>
+    ThreadUtils.parmap(
+      spark.sessionState.functionRegistry.listFunction(),
+      prefix = "ExpressionInfoSuite-check-outputs-of-expression-examples",
+      maxThreads = Runtime.getRuntime.availableProcessors
+    ) { funcId =>
       // Examples can change settings. We clone the session to prevent tests clashing.
       val clonedSpark = spark.cloneSession()
       // Coalescing partitions can change result order, so disable it.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index 43aaa7e1eea..a8f55c8b4d6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -52,7 +52,9 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
       outputStreams.foreach(_.validateAtStart())
       numReceivers = inputStreams.count(_.isInstanceOf[ReceiverInputDStream[_]])
       inputStreamNameAndID = inputStreams.map(is => (is.name, is.id)).toSeq
+      // scalastyle:off parvector
       new ParVector(inputStreams.toVector).foreach(_.start())
+      // scalastyle:on parvector
     }
   }
 
@@ -62,7 +64,9 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
 
   def stop(): Unit = {
     this.synchronized {
+      // scalastyle:off parvector
       new ParVector(inputStreams.toVector).foreach(_.stop())
+      // scalastyle:on parvector
     }
   }
 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
index d1f9dfb7913..4e65bc75e43 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
@@ -314,8 +314,10 @@ private[streaming] object FileBasedWriteAheadLog {
     val groupSize = taskSupport.parallelismLevel.max(8)
 
     source.grouped(groupSize).flatMap { group =>
+      // scalastyle:off parvector
       val parallelCollection = new ParVector(group.toVector)
       parallelCollection.tasksupport = taskSupport
+      // scalastyle:on parvector
       parallelCollection.map(handler)
     }.flatten
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org