Posted to commits@spark.apache.org by jo...@apache.org on 2014/09/14 01:22:12 UTC

git commit: [SPARK-3030] [PySpark] Reuse Python worker

Repository: spark
Updated Branches:
  refs/heads/master 0f8c4edf4 -> 2aea0da84


[SPARK-3030] [PySpark] Reuse Python worker

Reuse Python workers to avoid the overhead of forking a Python process for each task. It also tracks the broadcasts already sent to each worker, so repeated broadcasts are not re-sent.

This reduces the time for a dummy task from 22ms to 13ms (about -40%), which helps lower latency for Spark Streaming.
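
As a rough sketch (assuming an existing `sc` SparkContext in a PySpark shell), a dummy-task timing like the one above could be measured as follows; actual numbers depend on the machine:
```
import time

# time a trivial single-partition job; with worker reuse the fork()
# cost is only paid the first time a task hits a given executor
start = time.time()
sc.parallelize([0], 1).count()
print "dummy task took %.1f ms" % ((time.time() - start) * 1000)
```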

For a job with a broadcast variable (43 MB after compression):
```
    b = sc.broadcast(set(range(30000000)))
    print sc.parallelize(range(24000), 100).filter(lambda x: x in b.value).count()
```
It finishes in 281s without worker reuse and in 65s with worker reuse (4 CPUs). Reusing the worker saves about 9 seconds per task that would otherwise be spent transferring and deserializing the broadcast.
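
The bookkeeping behind this is sketched below in plain Python for illustration (the real logic lives in `PythonRDD.scala` in the diff that follows): the JVM remembers which broadcast ids each worker has already received, asks the worker to drop stale ones by sending a negative id encoded as `-bid - 1` (valid since `bid >= 0`), and only ships the values of broadcasts the worker has not seen yet.
```
# Illustration only (hypothetical helper, not part of PySpark):
# compute which broadcast ids to write for one task.
def broadcasts_to_send(old_bids, new_bids):
    # negative entries tell the worker to drop a broadcast (-bid - 1),
    # positive entries are new broadcasts whose values will follow
    to_remove = [-bid - 1 for bid in old_bids - new_bids]
    to_add = sorted(new_bids - old_bids)
    return to_remove + to_add

# worker already holds broadcasts {0, 1}; the next task needs {1, 2}
print broadcasts_to_send({0, 1}, {1, 2})  # -> [-1, 2]
```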

It is enabled by default and can be disabled by setting `spark.python.worker.reuse = false`.
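
For example, a minimal sketch of disabling it from application code (the property can equally be set in `spark-defaults.conf` or passed via `--conf` to `spark-submit`):
```
from pyspark import SparkConf, SparkContext

# fall back to forking a fresh Python worker for every task
conf = SparkConf().set("spark.python.worker.reuse", "false")
sc = SparkContext(conf=conf)
```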

Author: Davies Liu <da...@gmail.com>

Closes #2259 from davies/reuse-worker and squashes the following commits:

f11f617 [Davies Liu] Merge branch 'master' into reuse-worker
3939f20 [Davies Liu] fix bug in serializer in mllib
cf1c55e [Davies Liu] address comments
3133a60 [Davies Liu] fix accumulator with reused worker
760ab1f [Davies Liu] do not reuse worker if there are any exceptions
7abb224 [Davies Liu] refactor: sychronized with itself
ac3206e [Davies Liu] renaming
8911f44 [Davies Liu] synchronized getWorkerBroadcasts()
6325fc1 [Davies Liu] bugfix: bid >= 0
e0131a2 [Davies Liu] fix name of config
583716e [Davies Liu] only reuse completed and not interrupted worker
ace2917 [Davies Liu] kill python worker after timeout
6123d0f [Davies Liu] track broadcasts for each worker
8d2f08c [Davies Liu] reuse python worker


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2aea0da8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2aea0da8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2aea0da8

Branch: refs/heads/master
Commit: 2aea0da84c58a179917311290083456dfa043db7
Parents: 0f8c4ed
Author: Davies Liu <da...@gmail.com>
Authored: Sat Sep 13 16:22:04 2014 -0700
Committer: Josh Rosen <jo...@apache.org>
Committed: Sat Sep 13 16:22:04 2014 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/SparkEnv.scala  |  8 ++
 .../org/apache/spark/api/python/PythonRDD.scala | 58 ++++++++++---
 .../spark/api/python/PythonWorkerFactory.scala  | 85 +++++++++++++++++---
 docs/configuration.md                           | 10 +++
 python/pyspark/daemon.py                        | 38 ++++-----
 python/pyspark/mllib/_common.py                 | 12 ++-
 python/pyspark/serializers.py                   |  4 +
 python/pyspark/tests.py                         | 35 ++++++++
 python/pyspark/worker.py                        |  9 ++-
 9 files changed, 208 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2aea0da8/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index dd95e40..009ed64 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -108,6 +108,14 @@ class SparkEnv (
       pythonWorkers.get(key).foreach(_.stopWorker(worker))
     }
   }
+
+  private[spark]
+  def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
+    synchronized {
+      val key = (pythonExec, envVars)
+      pythonWorkers.get(key).foreach(_.releaseWorker(worker))
+    }
+  }
 }
 
 object SparkEnv extends Logging {

http://git-wip-us.apache.org/repos/asf/spark/blob/2aea0da8/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ae80103..ca8eef5 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -23,6 +23,7 @@ import java.nio.charset.Charset
 import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
 
 import scala.collection.JavaConversions._
+import scala.collection.mutable
 import scala.language.existentials
 import scala.reflect.ClassTag
 import scala.util.{Try, Success, Failure}
@@ -52,6 +53,7 @@ private[spark] class PythonRDD(
   extends RDD[Array[Byte]](parent) {
 
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
+  val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
 
   override def getPartitions = parent.partitions
 
@@ -63,19 +65,26 @@ private[spark] class PythonRDD(
     val localdir = env.blockManager.diskBlockManager.localDirs.map(
       f => f.getPath()).mkString(",")
     envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread
+    if (reuse_worker) {
+      envVars += ("SPARK_REUSE_WORKER" -> "1")
+    }
     val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
 
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = new WriterThread(env, worker, split, context)
 
+    var complete_cleanly = false
     context.addTaskCompletionListener { context =>
       writerThread.shutdownOnTaskCompletion()
-
-      // Cleanup the worker socket. This will also cause the Python worker to exit.
-      try {
-        worker.close()
-      } catch {
-        case e: Exception => logWarning("Failed to close worker socket", e)
+      if (reuse_worker && complete_cleanly) {
+        env.releasePythonWorker(pythonExec, envVars.toMap, worker)
+      } else {
+        try {
+          worker.close()
+        } catch {
+          case e: Exception =>
+            logWarning("Failed to close worker socket", e)
+        }
       }
     }
 
@@ -133,6 +142,7 @@ private[spark] class PythonRDD(
                 stream.readFully(update)
                 accumulator += Collections.singletonList(update)
               }
+               complete_cleanly = true
               null
           }
         } catch {
@@ -195,11 +205,26 @@ private[spark] class PythonRDD(
           PythonRDD.writeUTF(include, dataOut)
         }
         // Broadcast variables
-        dataOut.writeInt(broadcastVars.length)
+        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+        val newBids = broadcastVars.map(_.id).toSet
+        // number of different broadcasts
+        val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size
+        dataOut.writeInt(cnt)
+        for (bid <- oldBids) {
+          if (!newBids.contains(bid)) {
+            // remove the broadcast from worker
+            dataOut.writeLong(- bid - 1)  // bid >= 0
+            oldBids.remove(bid)
+          }
+        }
         for (broadcast <- broadcastVars) {
-          dataOut.writeLong(broadcast.id)
-          dataOut.writeInt(broadcast.value.length)
-          dataOut.write(broadcast.value)
+          if (!oldBids.contains(broadcast.id)) {
+            // send new broadcast
+            dataOut.writeLong(broadcast.id)
+            dataOut.writeInt(broadcast.value.length)
+            dataOut.write(broadcast.value)
+            oldBids.add(broadcast.id)
+          }
         }
         dataOut.flush()
         // Serialized command:
@@ -207,17 +232,18 @@ private[spark] class PythonRDD(
         dataOut.write(command)
         // Data values
         PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
         dataOut.flush()
       } catch {
         case e: Exception if context.isCompleted || context.isInterrupted =>
           logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+          worker.shutdownOutput()
 
         case e: Exception =>
           // We must avoid throwing exceptions here, because the thread uncaught exception handler
           // will kill the whole executor (see org.apache.spark.executor.Executor).
           _exception = e
-      } finally {
-        Try(worker.shutdownOutput()) // kill Python worker process
+          worker.shutdownOutput()
       }
     }
   }
@@ -278,6 +304,14 @@ private object SpecialLengths {
 private[spark] object PythonRDD extends Logging {
   val UTF8 = Charset.forName("UTF-8")
 
+  // remember the broadcasts sent to each worker
+  private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
+  private def getWorkerBroadcasts(worker: Socket) = {
+    synchronized {
+      workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
+    }
+  }
+
   /**
    * Adapter for calling SparkContext#runJob from Python.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/2aea0da8/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 4c4796f..71bdf0f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -40,7 +40,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0
-  var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+  val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+  val idleWorkers = new mutable.Queue[Socket]()
+  var lastActivity = 0L
+  new MonitorThread().start()
 
   var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
 
@@ -51,6 +54,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
 
   def create(): Socket = {
     if (useDaemon) {
+      synchronized {
+        if (idleWorkers.size > 0) {
+          return idleWorkers.dequeue()
+        }
+      }
       createThroughDaemon()
     } else {
       createSimpleWorker()
@@ -199,9 +207,44 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
     }
   }
 
+  /**
+   * Monitor all the idle workers, kill them after timeout.
+   */
+  private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") {
+
+    setDaemon(true)
+
+    override def run() {
+      while (true) {
+        synchronized {
+          if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) {
+            cleanupIdleWorkers()
+            lastActivity = System.currentTimeMillis()
+          }
+        }
+        Thread.sleep(10000)
+      }
+    }
+  }
+
+  private def cleanupIdleWorkers() {
+    while (idleWorkers.length > 0) {
+      val worker = idleWorkers.dequeue()
+      try {
+        // the worker will exit after closing the socket
+        worker.close()
+      } catch {
+        case e: Exception =>
+          logWarning("Failed to close worker socket", e)
+      }
+    }
+  }
+
   private def stopDaemon() {
     synchronized {
       if (useDaemon) {
+        cleanupIdleWorkers()
+
         // Request shutdown of existing daemon by sending SIGTERM
         if (daemon != null) {
           daemon.destroy()
@@ -220,23 +263,43 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   }
 
   def stopWorker(worker: Socket) {
-    if (useDaemon) {
-      if (daemon != null) {
-        daemonWorkers.get(worker).foreach { pid =>
-          // tell daemon to kill worker by pid
-          val output = new DataOutputStream(daemon.getOutputStream)
-          output.writeInt(pid)
-          output.flush()
-          daemon.getOutputStream.flush()
+    synchronized {
+      if (useDaemon) {
+        if (daemon != null) {
+          daemonWorkers.get(worker).foreach { pid =>
+            // tell daemon to kill worker by pid
+            val output = new DataOutputStream(daemon.getOutputStream)
+            output.writeInt(pid)
+            output.flush()
+            daemon.getOutputStream.flush()
+          }
         }
+      } else {
+        simpleWorkers.get(worker).foreach(_.destroy())
       }
-    } else {
-      simpleWorkers.get(worker).foreach(_.destroy())
     }
     worker.close()
   }
+
+  def releaseWorker(worker: Socket) {
+    if (useDaemon) {
+      synchronized {
+        lastActivity = System.currentTimeMillis()
+        idleWorkers.enqueue(worker)
+      }
+    } else {
+      // Cleanup the worker socket. This will also cause the Python worker to exit.
+      try {
+        worker.close()
+      } catch {
+        case e: Exception =>
+          logWarning("Failed to close worker socket", e)
+      }
+    }
+  }
 }
 
 private object PythonWorkerFactory {
   val PROCESS_WAIT_TIMEOUT_MS = 10000
+  val IDLE_WORKER_TIMEOUT_MS = 60000  // kill idle workers after 1 minute
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2aea0da8/docs/configuration.md
----------------------------------------------------------------------
diff --git a/docs/configuration.md b/docs/configuration.md
index 36178ef..af16489 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -207,6 +207,16 @@ Apart from these, the following properties are also available, and may be useful
   </td>
 </tr>
 <tr>
+  <td><code>spark.python.worker.reuse</code></td>
+  <td>true</td>
+  <td>
+    Whether to reuse Python workers. If yes, it will use a fixed number of Python
+    workers and does not need to fork() a Python process for every task. This is
+    very useful when there is a large broadcast, since the broadcast then does not
+    need to be transferred from the JVM to the Python worker for every task.
+  </td>
+</tr>
+<tr>
   <td><code>spark.executorEnv.[EnvironmentVariableName]</code></td>
   <td>(none)</td>
   <td>

http://git-wip-us.apache.org/repos/asf/spark/blob/2aea0da8/python/pyspark/daemon.py
----------------------------------------------------------------------
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 15445ab..64d6202 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -23,6 +23,7 @@ import socket
 import sys
 import traceback
 import time
+import gc
 from errno import EINTR, ECHILD, EAGAIN
 from socket import AF_INET, SOCK_STREAM, SOMAXCONN
 from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
@@ -46,17 +47,6 @@ def worker(sock):
     signal.signal(SIGCHLD, SIG_DFL)
     signal.signal(SIGTERM, SIG_DFL)
 
-    # Blocks until the socket is closed by draining the input stream
-    # until it raises an exception or returns EOF.
-    def waitSocketClose(sock):
-        try:
-            while True:
-                # Empty string is returned upon EOF (and only then).
-                if sock.recv(4096) == '':
-                    return
-        except:
-            pass
-
     # Read the socket using fdopen instead of socket.makefile() because the latter
     # seems to be very slow; note that we need to dup() the file descriptor because
     # otherwise writes also cause a seek that makes us miss data on the read side.
@@ -64,17 +54,13 @@ def worker(sock):
     outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
     exit_code = 0
     try:
-        # Acknowledge that the fork was successful
-        write_int(os.getpid(), outfile)
-        outfile.flush()
         worker_main(infile, outfile)
     except SystemExit as exc:
-        exit_code = exc.code
+        exit_code = compute_real_exit_code(exc.code)
     finally:
         outfile.flush()
-        # The Scala side will close the socket upon task completion.
-        waitSocketClose(sock)
-        os._exit(compute_real_exit_code(exit_code))
+        if exit_code:
+            os._exit(exit_code)
 
 
 # Cleanup zombie children
@@ -111,6 +97,8 @@ def manager():
     signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
     signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP
 
+    reuse = os.environ.get("SPARK_REUSE_WORKER")
+
     # Initialization complete
     try:
         while True:
@@ -163,7 +151,19 @@ def manager():
                     # in child process
                     listen_sock.close()
                     try:
-                        worker(sock)
+                        # Acknowledge that the fork was successful
+                        outfile = sock.makefile("w")
+                        write_int(os.getpid(), outfile)
+                        outfile.flush()
+                        outfile.close()
+                        while True:
+                            worker(sock)
+                            if not reuse:
+                                # wait for closing
+                                while sock.recv(1024):
+                                    pass
+                                break
+                            gc.collect()
                     except:
                         traceback.print_exc()
                         os._exit(1)

http://git-wip-us.apache.org/repos/asf/spark/blob/2aea0da8/python/pyspark/mllib/_common.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index bb60d3d..68f6033 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -21,7 +21,7 @@ import numpy
 from numpy import ndarray, float64, int64, int32, array_equal, array
 from pyspark import SparkContext, RDD
 from pyspark.mllib.linalg import SparseVector
-from pyspark.serializers import Serializer
+from pyspark.serializers import FramedSerializer
 
 
 """
@@ -451,18 +451,16 @@ def _serialize_rating(r):
     return ba
 
 
-class RatingDeserializer(Serializer):
+class RatingDeserializer(FramedSerializer):
 
-    def loads(self, stream):
-        length = struct.unpack("!i", stream.read(4))[0]
-        ba = stream.read(length)
-        res = ndarray(shape=(3, ), buffer=ba, dtype=float64, offset=4)
+    def loads(self, string):
+        res = ndarray(shape=(3, ), buffer=string, dtype=float64, offset=4)
         return int(res[0]), int(res[1]), res[2]
 
     def load_stream(self, stream):
         while True:
             try:
-                yield self.loads(stream)
+                yield self._read_with_length(stream)
             except struct.error:
                 return
             except EOFError:

http://git-wip-us.apache.org/repos/asf/spark/blob/2aea0da8/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index a5f9341..ec3c6f0 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -144,6 +144,8 @@ class FramedSerializer(Serializer):
 
     def _read_with_length(self, stream):
         length = read_int(stream)
+        if length == SpecialLengths.END_OF_DATA_SECTION:
+            raise EOFError
         obj = stream.read(length)
         if obj == "":
             raise EOFError
@@ -438,6 +440,8 @@ class UTF8Deserializer(Serializer):
 
     def loads(self, stream):
         length = read_int(stream)
+        if length == SpecialLengths.END_OF_DATA_SECTION:
+            raise EOFError
         s = stream.read(length)
         return s.decode("utf-8") if self.use_unicode else s
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2aea0da8/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b687d69..747cd17 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1222,11 +1222,46 @@ class TestWorker(PySparkTestCase):
         except OSError:
             self.fail("daemon had been killed")
 
+        # run a normal job
+        rdd = self.sc.parallelize(range(100), 1)
+        self.assertEqual(100, rdd.map(str).count())
+
     def test_fd_leak(self):
         N = 1100  # fd limit is 1024 by default
         rdd = self.sc.parallelize(range(N), N)
         self.assertEquals(N, rdd.count())
 
+    def test_after_exception(self):
+        def raise_exception(_):
+            raise Exception()
+        rdd = self.sc.parallelize(range(100), 1)
+        self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
+        self.assertEqual(100, rdd.map(str).count())
+
+    def test_after_jvm_exception(self):
+        tempFile = tempfile.NamedTemporaryFile(delete=False)
+        tempFile.write("Hello World!")
+        tempFile.close()
+        data = self.sc.textFile(tempFile.name, 1)
+        filtered_data = data.filter(lambda x: True)
+        self.assertEqual(1, filtered_data.count())
+        os.unlink(tempFile.name)
+        self.assertRaises(Exception, lambda: filtered_data.count())
+
+        rdd = self.sc.parallelize(range(100), 1)
+        self.assertEqual(100, rdd.map(str).count())
+
+    def test_accumulator_when_reuse_worker(self):
+        from pyspark.accumulators import INT_ACCUMULATOR_PARAM
+        acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
+        self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x))
+        self.assertEqual(sum(range(100)), acc1.value)
+
+        acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
+        self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x))
+        self.assertEqual(sum(range(100)), acc2.value)
+        self.assertEqual(sum(range(100)), acc1.value)
+
 
 class TestSparkSubmit(unittest.TestCase):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2aea0da8/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 6805063..61b8a74 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -69,9 +69,14 @@ def main(infile, outfile):
         ser = CompressedSerializer(pickleSer)
         for _ in range(num_broadcast_variables):
             bid = read_long(infile)
-            value = ser._read_with_length(infile)
-            _broadcastRegistry[bid] = Broadcast(bid, value)
+            if bid >= 0:
+                value = ser._read_with_length(infile)
+                _broadcastRegistry[bid] = Broadcast(bid, value)
+            else:
+                bid = - bid - 1
+                _broadcastRegistry.remove(bid)
 
+        _accumulatorRegistry.clear()
         command = pickleSer._read_with_length(infile)
         (func, deserializer, serializer) = command
         init_time = time.time()

