You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2013/09/03 03:38:35 UTC

[03/19] git commit: Allow PySpark to launch worker.py directly on Windows

Allow PySpark to launch worker.py directly on Windows


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/6550e5e6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/6550e5e6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/6550e5e6

Branch: refs/heads/master
Commit: 6550e5e60c501cbce40f0e968fc674e499f21949
Parents: 3c520fe
Author: Matei Zaharia <ma...@eecs.berkeley.edu>
Authored: Sun Sep 1 18:06:15 2013 -0700
Committer: Matei Zaharia <ma...@eecs.berkeley.edu>
Committed: Sun Sep 1 18:06:15 2013 -0700

----------------------------------------------------------------------
 .../spark/api/python/PythonWorkerFactory.scala  | 107 +++++++++++++++++--
 python/pyspark/worker.py                        |  11 +-
 2 files changed, 106 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6550e5e6/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 08e3f67..67d4572 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -17,8 +17,8 @@
 
 package org.apache.spark.api.python
 
-import java.io.{File, DataInputStream, IOException}
-import java.net.{Socket, SocketException, InetAddress}
+import java.io.{OutputStreamWriter, File, DataInputStream, IOException}
+import java.net.{ServerSocket, Socket, SocketException, InetAddress}
 
 import scala.collection.JavaConversions._
 
@@ -26,11 +26,30 @@ import org.apache.spark._
 
 private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
     extends Logging {
+
+  // Because forking processes from Java is expensive, we prefer to launch a single Python daemon
+  // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently
+  // only works on UNIX-based systems because it uses signals for child management, so when
+  // os.name starts with "Windows" we instead launch workers (pyspark/worker.py) directly.
+  val useDaemon = !System.getProperty("os.name").startsWith("Windows")
+
   var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0
 
   def create(): Socket = {
+    if (useDaemon) {
+      createThroughDaemon()  // ask the shared pyspark/daemon.py process to fork a worker
+    } else {
+      createSimpleWorker()   // Windows fallback: spawn pyspark/worker.py directly
+    }
+  }
+
+  /**
+   * Connect to a worker launched through pyspark/daemon.py, which forks python processes itself
+   * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
+   */
+  private def createThroughDaemon(): Socket = {
     synchronized {
       // Start the daemon if it hasn't been started
       startDaemon()
@@ -50,6 +69,78 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
     }
   }
 
+  /**
+   * Launch a worker by executing pyspark/worker.py directly and sending it our port over stdin.
+   */
+  private def createSimpleWorker(): Socket = {
+    var serverSocket: ServerSocket = null
+    try {
+      serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+      // Create and start the worker, prepending $SPARK_HOME/python to its PYTHONPATH.
+      // NOTE(review): if SPARK_HOME or PYTHONPATH is unset, get() returns null and "null" lands in a path.
+      val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
+      val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/worker.py"))
+      val workerEnv = pb.environment()
+      workerEnv.putAll(envVars)
+      val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH")
+      workerEnv.put("PYTHONPATH", pythonPath)
+      val worker = pb.start()
+
+      // Copy the worker's stderr to ours on a background (daemon) thread
+      new Thread("stderr reader for " + pythonExec) {
+        setDaemon(true)
+        override def run() {
+          scala.util.control.Exception.ignoring(classOf[IOException]) {
+            // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+            val in = worker.getErrorStream
+            val buf = new Array[Byte](1024)
+            var len = in.read(buf)
+            while (len != -1) {
+              System.err.write(buf, 0, len)
+              len = in.read(buf)
+            }
+          }
+        }
+      }.start()
+
+      // Redirect worker's stdout to our stderr too; task data flows over the socket instead
+      new Thread("stdout reader for " + pythonExec) {
+        setDaemon(true)
+        override def run() {
+          scala.util.control.Exception.ignoring(classOf[IOException]) {
+            // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+            val in = worker.getInputStream
+            val buf = new Array[Byte](1024)
+            var len = in.read(buf)
+            while (len != -1) {
+              System.err.write(buf, 0, len)
+              len = in.read(buf)
+            }
+          }
+        }
+      }.start()
+
+      // Tell the worker which local port to connect back to (worker.py reads it from stdin)
+      val out = new OutputStreamWriter(worker.getOutputStream)
+      out.write(serverSocket.getLocalPort + "\n")
+      out.flush()
+      // Wait up to 10 seconds (setSoTimeout takes milliseconds) for the worker to connect back.
+      // NOTE(review): if accept() fails, the already-started worker process is not destroyed here.
+      serverSocket.setSoTimeout(10000)
+      try {
+        return serverSocket.accept()
+      } catch {
+        case e: Exception =>
+          throw new SparkException("Python worker did not connect back in time", e)
+      }
+    } finally {  // closes only the listening socket; the Socket returned by accept() is separate
+      if (serverSocket != null) {
+        serverSocket.close()
+      }
+    }
+    null  // unreachable: every path through the try above either returns or throws
+  }
+  /** Stop via stopDaemon(). NOTE(review): createSimpleWorker() keeps no handle on its workers. */
   def stop() {
     stopDaemon()
   }
@@ -73,12 +164,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
 
         // Redirect the stderr to ours
         new Thread("stderr reader for " + pythonExec) {
+          setDaemon(true)
           override def run() {
             scala.util.control.Exception.ignoring(classOf[IOException]) {
-              // FIXME HACK: We copy the stream on the level of bytes to
-              // attempt to dodge encoding problems.
+              // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
               val in = daemon.getErrorStream
-              var buf = new Array[Byte](1024)
+              val buf = new Array[Byte](1024)
               var len = in.read(buf)
               while (len != -1) {
                 System.err.write(buf, 0, len)
@@ -93,11 +184,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
 
         // Redirect further stdout output to our stderr
         new Thread("stdout reader for " + pythonExec) {
+          setDaemon(true)
           override def run() {
             scala.util.control.Exception.ignoring(classOf[IOException]) {
-              // FIXME HACK: We copy the stream on the level of bytes to
-              // attempt to dodge encoding problems.
-              var buf = new Array[Byte](1024)
+              // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+              val buf = new Array[Byte](1024)
               var len = in.read(buf)
               while (len != -1) {
                 System.err.write(buf, 0, len)

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6550e5e6/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 695f6df..d63c2aa 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -21,6 +21,7 @@ Worker that receives input from Piped RDD.
 import os
 import sys
 import time
+import socket
 import traceback
 from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
@@ -94,7 +95,9 @@ def main(infile, outfile):
 
 
 if __name__ == '__main__':
-    # Redirect stdout to stderr so that users must return values from functions.
-    old_stdout = os.fdopen(os.dup(1), 'w')
-    os.dup2(2, 1)
-    main(sys.stdin, old_stdout)
+    # Read from stdin the local port that the JVM (PythonWorkerFactory) is listening on
+    java_port = int(sys.stdin.readline())
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.connect(("127.0.0.1", java_port))
+    sock_file = sock.makefile("a+", 65536)  # one buffered file object used for input and output
+    main(sock_file, sock_file)