Posted to commits@spark.apache.org by we...@apache.org on 2020/02/17 18:29:43 UTC

[spark] branch branch-3.0 updated: [SPARK-22590][SQL] Copy sparkContext.localproperties to child thread in BroadcastExchangeExec.executionContext

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 08131bb  [SPARK-22590][SQL] Copy sparkContext.localproperties to child thread in BroadcastExchangeExec.executionContext
08131bb is described below

commit 08131bbcfb7257caf146eba4af65ba0e407d34dd
Author: Ajith <aj...@gmail.com>
AuthorDate: Tue Feb 18 02:26:52 2020 +0800

    [SPARK-22590][SQL] Copy sparkContext.localproperties to child thread in BroadcastExchangeExec.executionContext
    
    ### What changes were proposed in this pull request?
    In `org.apache.spark.sql.execution.exchange.BroadcastExchangeExec#relationFuture`, make a copy of `org.apache.spark.SparkContext#localProperties` and pass it to the broadcast execution thread in `org.apache.spark.sql.execution.exchange.BroadcastExchangeExec#executionContext`.
    
    ### Why are the changes needed?
    When executing `BroadcastExchangeExec`, `relationFuture` is evaluated on a separate thread. Such threads inherit the `localProperties` of `sparkContext` because they are child threads.
    These threads are created in the `executionContext` (a thread pool). Each thread pool keeps idle threads alive for a default `keepAliveSeconds` of 60 seconds.
    When an idle pooled thread is reused for a subsequent query, it does not re-inherit the thread-local properties from the Spark context (thread properties are inherited only on thread creation), so it ends up with stale or missing properties. This causes task-set properties to go missing when the child thread transfers them via `sparkContext.runJob`/`submitJob`.
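    
    A minimal standalone sketch of the failure mode (JDK only; the object and
    property names are illustrative, not part of this patch). The values of an
    `InheritableThreadLocal`, which backs `SparkContext#localProperties`, are
    copied only when a thread is created:
    
    ```scala
    import java.util.concurrent.Executors
    
    object StalePropsDemo {
      // Stand-in for SparkContext#localProperties (an InheritableThreadLocal).
      val prop = new InheritableThreadLocal[String] {
        override def initialValue(): String = "unset"
      }
    
      def main(args: Array[String]): Unit = {
        val pool = Executors.newFixedThreadPool(1)
        prop.set("query-1")
        // The first task forces worker creation; the worker inherits "query-1".
        pool.submit(new Runnable { def run(): Unit = println(prop.get) }).get()
        prop.set("query-2")
        // The reused worker never re-inherits: this still prints "query-1".
        pool.submit(new Runnable { def run(): Unit = println(prop.get) }).get()
        pool.shutdown()
      }
    }
    ```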
    
    ### Does this PR introduce any user-facing change?
    No
    
    ### How was this patch tested?
    Added UT
    
    Closes #27266 from ajithme/broadcastlocalprop.
    
    Authored-by: Ajith <aj...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 2854091d12d670b014c41713e72153856f4d3f6a)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../scala/org/apache/spark/util/ThreadUtils.scala  | 17 +++++++++++++
 .../apache/spark/sql/execution/SQLExecution.scala  | 10 ++++----
 .../sql/execution/basicPhysicalOperators.scala     |  5 ++--
 .../execution/exchange/BroadcastExchangeExec.scala | 16 ++++---------
 .../sql/internal/ExecutorSideSQLConfSuite.scala    | 28 ++++++++++++++++++++++
 5 files changed, 56 insertions(+), 20 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index de39e4b..e7872bb 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.util
 
 import java.util.concurrent._
+import java.util.concurrent.{Future => JFuture}
 import java.util.concurrent.locks.ReentrantLock
 
 import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor, Future}
@@ -304,6 +305,22 @@ private[spark] object ThreadUtils {
   }
   // scalastyle:on awaitresult
 
+  @throws(classOf[SparkException])
+  def awaitResult[T](future: JFuture[T], atMost: Duration): T = {
+    try {
+      atMost match {
+        case Duration.Inf => future.get()
+        case _ => future.get(atMost._1, atMost._2)
+      }
+    } catch {
+      case e: SparkFatalException =>
+        throw e.throwable
+      case NonFatal(t)
+        if !t.isInstanceOf[TimeoutException] && !t.isInstanceOf[RpcAbortException] =>
+        throw new SparkException("Exception thrown in awaitResult: ", t)
+    }
+  }
+
   // scalastyle:off awaitready
   /**
    * Preferred alternative to `Await.ready()`.
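
The `awaitResult` overload added above mirrors the existing Scala-`Future` variant for `java.util.concurrent.Future`: it unwraps `SparkFatalException` and wraps other non-fatal failures (except timeouts and RPC aborts) in `SparkException`. A usage sketch, assuming code living inside Spark itself (`ThreadUtils` is `private[spark]`); the pool and task are illustrative:

```scala
import java.util.concurrent.{Callable, Executors}

import scala.concurrent.duration.Duration

import org.apache.spark.util.ThreadUtils

val pool = Executors.newSingleThreadExecutor()
val jFuture = pool.submit(new Callable[Int] { def call(): Int = 21 * 2 })
// Blocks until completion; non-fatal failures surface as SparkException.
val result = ThreadUtils.awaitResult(jFuture, Duration.Inf) // 42
pool.shutdown()
```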
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 59c503e..5e4f30a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -17,11 +17,9 @@
 
 package org.apache.spark.sql.execution
 
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future => JFuture}
 import java.util.concurrent.atomic.AtomicLong
 
-import scala.concurrent.{ExecutionContext, Future}
-
 import org.apache.spark.SparkContext
 import org.apache.spark.internal.config.Tests.IS_TESTING
 import org.apache.spark.sql.SparkSession
@@ -172,11 +170,11 @@ object SQLExecution {
    * SparkContext local properties are forwarded to execution thread
    */
   def withThreadLocalCaptured[T](
-      sparkSession: SparkSession, exec: ExecutionContext)(body: => T): Future[T] = {
+      sparkSession: SparkSession, exec: ExecutorService)(body: => T): JFuture[T] = {
     val activeSession = sparkSession
     val sc = sparkSession.sparkContext
     val localProps = Utils.cloneProperties(sc.getLocalProperties)
-    Future {
+    exec.submit(() => {
       val originalSession = SparkSession.getActiveSession
       val originalLocalProps = sc.getLocalProperties
       SparkSession.setActiveSession(activeSession)
@@ -190,6 +188,6 @@ object SQLExecution {
         SparkSession.clearActiveSession()
       }
       res
-    }(exec)
+    })
   }
 }
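
With the new signature, `withThreadLocalCaptured` takes an `ExecutorService` rather than an `ExecutionContext` and returns a Java future. It snapshots the active `SparkSession` plus a clone of the caller's local properties, then restores both inside the submitted task, whether the worker thread is fresh or reused. A call-site sketch (assumes a local `SparkSession`; the property key is illustrative, and `ThreadUtils` again requires being inside Spark):

```scala
import java.util.concurrent.Executors

import scala.concurrent.duration.Duration

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.util.ThreadUtils

val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
val pool = Executors.newFixedThreadPool(1)

spark.sparkContext.setLocalProperty("spark.sql.y", "abc")
val f = SQLExecution.withThreadLocalCaptured(spark, pool) {
  // Runs on a pool thread, yet observes the caller's local properties.
  spark.sparkContext.getLocalProperty("spark.sql.y")
}
assert(ThreadUtils.awaitResult(f, Duration.Inf) == "abc")
pool.shutdown()
spark.stop()
```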
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index c35c484..4e65bbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql.execution
 
+import java.util.concurrent.{Future => JFuture}
 import java.util.concurrent.TimeUnit._
 
 import scala.collection.mutable
-import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.ExecutionContext
 import scala.concurrent.duration.Duration
 
 import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
@@ -746,7 +747,7 @@ case class SubqueryExec(name: String, child: SparkPlan)
     "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"))
 
   @transient
-  private lazy val relationFuture: Future[Array[InternalRow]] = {
+  private lazy val relationFuture: JFuture[Array[InternalRow]] = {
     // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
     SQLExecution.withThreadLocalCaptured[Array[InternalRow]](
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 36f0d17..65e6b7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.joins.HashedRelation
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
-import org.apache.spark.util.{SparkFatalException, ThreadUtils}
+import org.apache.spark.util.{SparkFatalException, ThreadUtils, Utils}
 
 /**
  * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
@@ -73,13 +73,8 @@ case class BroadcastExchangeExec(
 
   @transient
   private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
-    // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
-    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    val task = new Callable[broadcast.Broadcast[Any]]() {
-      override def call(): broadcast.Broadcast[Any] = {
-        // This will run in another thread. Set the execution id so that we can connect these jobs
-        // with the correct execution.
-        SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+    SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
+      sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
           try {
             // Setup a job group here so later it may get cancelled by groupId if necessary.
             sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
@@ -121,7 +116,7 @@ case class BroadcastExchangeExec(
             val broadcasted = sparkContext.broadcast(relation)
             longMetric("broadcastTime") += NANOSECONDS.toMillis(
               System.nanoTime() - beforeBroadcast)
-
+            val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
             SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
             promise.success(broadcasted)
             broadcasted
@@ -146,10 +141,7 @@ case class BroadcastExchangeExec(
               promise.failure(e)
               throw e
           }
-        }
-      }
     }
-    BroadcastExchangeExec.executionContext.submit[broadcast.Broadcast[Any]](task)
   }
 
   override protected def doPrepare(): Unit = {
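
Note that the `executionId` lookup moves inside the future body (the single `+` line in the hunk above). This works only because `withThreadLocalCaptured` clones the caller's local properties onto the worker thread before the body runs, so `SQLExecution.EXECUTION_ID_KEY` resolves there exactly as it would on the caller. A condensed sketch of the resulting shape, with Spark-internal names as in the patch and `buildRelation()` as a hypothetical stand-in for the collect/transform steps (job-group setup and error handling elided):

```scala
@transient
private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] =
  SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
    sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
    // The worker thread now carries the caller's local properties, so the
    // execution id set by the caller resolves correctly here.
    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
    val broadcasted = sparkContext.broadcast(buildRelation()) // hypothetical helper
    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
    broadcasted
  }
```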
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
index 46d0c64..888772c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
@@ -159,6 +159,34 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
       }
     }
   }
+
+  test("SPARK-22590 propagate local properties to broadcast execution thread") {
+    withSQLConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD.key -> "1") {
+      val df1 = Seq(true).toDF()
+      val confKey = "spark.sql.y"
+      val confValue1 = UUID.randomUUID().toString()
+      val confValue2 = UUID.randomUUID().toString()
+
+      def generateBroadcastDataFrame(confKey: String, confValue: String): Dataset[Boolean] = {
+        val df = spark.range(1).mapPartitions { _ =>
+          Iterator(TaskContext.get.getLocalProperty(confKey) == confValue)
+        }
+        df.hint("broadcast")
+      }
+
+      // set local property and assert
+      val df2 = generateBroadcastDataFrame(confKey, confValue1)
+      spark.sparkContext.setLocalProperty(confKey, confValue1)
+      val checks = df1.join(df2).collect()
+      assert(checks.forall(_.toSeq == Seq(true, true)))
+
+      // change local property and re-assert
+      val df3 = generateBroadcastDataFrame(confKey, confValue2)
+      spark.sparkContext.setLocalProperty(confKey, confValue2)
+      val checks2 = df1.join(df3).collect()
+      assert(checks2.forall(_.toSeq == Seq(true, true)))
+    }
+  }
 }
 
 case class SQLConfAssertPlan(confToCheck: Seq[(String, String)]) extends LeafExecNode {

