You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/05/16 09:44:32 UTC

spark git commit: [SPARK-7655][Core][SQL] Remove 'scala.concurrent.ExecutionContext.Implicits.global' in 'ask' and 'BroadcastHashJoin'

Repository: spark
Updated Branches:
  refs/heads/master 0ac8b01a0 -> 47e7ffe36


[SPARK-7655][Core][SQL] Remove 'scala.concurrent.ExecutionContext.Implicits.global' in 'ask' and 'BroadcastHashJoin'

Both `AkkaRpcEndpointRef.ask` and `BroadcastHashJoin` use `scala.concurrent.ExecutionContext.Implicits.global`. However, the tasks in `BroadcastHashJoin` are usually long-running and can occupy all the threads in `global`. When that happens, `ask` never gets a chance to process its replies.

For `ask`, the tasks are very simple, so they can run on `MoreExecutors.sameThreadExecutor()` (exposed here as `ThreadUtils.sameThread`). For `BroadcastHashJoin`, it is better to use a dedicated pool created with `ThreadUtils.newDaemonCachedThreadPool`, capped at 1024 threads.

Author: zsxwing <zs...@gmail.com>

Closes #6200 from zsxwing/SPARK-7655-2 and squashes the following commits:

cfdc605 [zsxwing] Remove redundant import and minor doc fix
cf83153 [zsxwing] Add "sameThread" and "newDaemonCachedThreadPool with maxThreadNumber" to ThreadUtils
08ad0ee [zsxwing] Remove 'scala.concurrent.ExecutionContext.Implicits.global' in 'ask' and 'BroadcastHashJoin'


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/47e7ffe3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/47e7ffe3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/47e7ffe3

Branch: refs/heads/master
Commit: 47e7ffe36b8a8a246fe9af522aff480d19c0c8a6
Parents: 0ac8b01
Author: zsxwing <zs...@gmail.com>
Authored: Sat May 16 00:44:29 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sat May 16 00:44:29 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/rpc/akka/AkkaRpcEnv.scala  |  8 ++++---
 .../org/apache/spark/util/ThreadUtils.scala     | 24 +++++++++++++++++++-
 .../apache/spark/util/ThreadUtilsSuite.scala    | 12 ++++++++++
 .../sql/execution/joins/BroadcastHashJoin.scala | 10 ++++++--
 4 files changed, 48 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/47e7ffe3/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index ba0d468..0161962 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -29,9 +29,11 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add
 import akka.event.Logging.Error
 import akka.pattern.{ask => akkaAsk}
 import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
+import com.google.common.util.concurrent.MoreExecutors
+
 import org.apache.spark.{SparkException, Logging, SparkConf}
 import org.apache.spark.rpc._
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils}
+import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}
 
 /**
  * A RpcEnv implementation based on Akka.
@@ -294,8 +296,8 @@ private[akka] class AkkaRpcEndpointRef(
   }
 
   override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = {
-    import scala.concurrent.ExecutionContext.Implicits.global
     actorRef.ask(AkkaMessage(message, true))(timeout).flatMap {
+      // The function will run in the calling thread, so it should be short and never block.
       case msg @ AkkaMessage(message, reply) =>
         if (reply) {
           logError(s"Receive $msg but the sender cannot reply")
@@ -305,7 +307,7 @@ private[akka] class AkkaRpcEndpointRef(
         }
       case AkkaFailure(e) =>
         Future.failed(e)
-    }.mapTo[T]
+    }(ThreadUtils.sameThread).mapTo[T]
   }
 
   override def toString: String = s"${getClass.getSimpleName}($actorRef)"

http://git-wip-us.apache.org/repos/asf/spark/blob/47e7ffe3/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 098a4b7..ca5624a 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -20,10 +20,22 @@ package org.apache.spark.util
 
 import java.util.concurrent._
 
-import com.google.common.util.concurrent.ThreadFactoryBuilder
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}
+
+import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
 
 private[spark] object ThreadUtils {
 
+  private val sameThreadExecutionContext =
+    ExecutionContext.fromExecutorService(MoreExecutors.sameThreadExecutor())
+
+  /**
+   * An `ExecutionContextExecutor` that runs each task in the thread that invokes `execute/submit`.
+   * The caller should make sure the tasks running in this `ExecutionContextExecutor` are short and
+   * never block.
+   */
+  def sameThread: ExecutionContextExecutor = sameThreadExecutionContext
+
   /**
    * Create a thread factory that names threads with a prefix and also sets the threads to daemon.
    */
@@ -41,6 +53,16 @@ private[spark] object ThreadUtils {
   }
 
   /**
+   * Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names
+   * are formatted as prefix-ID, where ID is a unique, sequentially assigned integer.
+   */
+  def newDaemonCachedThreadPool(prefix: String, maxThreadNumber: Int): ThreadPoolExecutor = {
+    val threadFactory = namedThreadFactory(prefix)
+    new ThreadPoolExecutor(
+      0, maxThreadNumber, 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable], threadFactory)
+  }
+
+  /**
    * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a
    * unique, sequentially assigned integer.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/47e7ffe3/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
index a3aa3e9..751d3df 100644
--- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
@@ -20,6 +20,9 @@ package org.apache.spark.util
 
 import java.util.concurrent.{CountDownLatch, TimeUnit}
 
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration._
+
 import org.scalatest.FunSuite
 
 class ThreadUtilsSuite extends FunSuite {
@@ -54,4 +57,13 @@ class ThreadUtilsSuite extends FunSuite {
       executor.shutdownNow()
     }
   }
+
+  test("sameThread") {
+    val callerThreadName = Thread.currentThread().getName()
+    val f = Future {
+      Thread.currentThread().getName()
+    }(ThreadUtils.sameThread)
+    val futureThreadName = Await.result(f, 10.seconds)
+    assert(futureThreadName === callerThreadName)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/47e7ffe3/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 05dd568..fe43fc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -18,10 +18,10 @@
 package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.ThreadUtils
 
 import scala.concurrent._
 import scala.concurrent.duration._
-import scala.concurrent.ExecutionContext.Implicits.global
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.expressions.{Row, Expression}
@@ -64,7 +64,7 @@ case class BroadcastHashJoin(
     val input: Array[Row] = buildPlan.execute().map(_.copy()).collect()
     val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
     sparkContext.broadcast(hashed)
-  }
+  }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
 
   protected override def doExecute(): RDD[Row] = {
     val broadcastRelation = Await.result(broadcastFuture, timeout)
@@ -74,3 +74,9 @@ case class BroadcastHashJoin(
     }
   }
 }
+
+object BroadcastHashJoin {
+
+  private val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 1024))
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org