Posted to commits@spark.apache.org by jo...@apache.org on 2015/05/18 21:55:20 UTC

spark git commit: [SPARK-6216] [PYSPARK] check python version of worker with driver

Repository: spark
Updated Branches:
  refs/heads/master 9dadf019b -> 32fbd297d


[SPARK-6216] [PYSPARK] check python version of worker with driver

This PR reverts #5404 and instead passes the driver's Python version into the JVM, checking it in the worker before deserializing the closure, so the check also works when the driver and worker run different major versions of Python.
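
A rough sketch of the handshake this change introduces (simplified from the context.py and worker.py hunks below; the JVM plumbing through PythonRDD.writeUTF is omitted, and check_python_version is an illustrative helper, not a function in the patch):

    import sys

    # Driver side (context.py): record the driver's Python version as a
    # "major.minor" string, e.g. "2.7", and hand it to the JVM alongside
    # pythonExec so PythonRDD can forward it to each worker.
    python_ver = "%d.%d" % sys.version_info[:2]

    # Worker side (worker.py): read the driver's version from the input
    # stream before deserializing the closure and fail fast on a mismatch.
    def check_python_version(driver_ver):
        worker_ver = "%d.%d" % sys.version_info[:2]
        if worker_ver != driver_ver:
            raise Exception("Python in worker has different version %s than that in "
                            "driver %s, PySpark cannot run with different minor versions"
                            % (worker_ver, driver_ver))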

Author: Davies Liu <da...@databricks.com>

Closes #6203 from davies/py_version and squashes the following commits:

b8fb76e [Davies Liu] fix test
6ce5096 [Davies Liu] use string for version
47c6278 [Davies Liu] check python version of worker with driver


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/32fbd297
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/32fbd297
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/32fbd297

Branch: refs/heads/master
Commit: 32fbd297dd651ba3ce4ce52aeb0488233149cdf9
Parents: 9dadf01
Author: Davies Liu <da...@databricks.com>
Authored: Mon May 18 12:55:13 2015 -0700
Committer: Josh Rosen <jo...@databricks.com>
Committed: Mon May 18 12:55:13 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/api/python/PythonRDD.scala   |  3 +++
 python/pyspark/context.py                               |  1 +
 python/pyspark/rdd.py                                   |  4 ++--
 python/pyspark/sql/context.py                           |  1 +
 python/pyspark/sql/functions.py                         |  4 ++--
 python/pyspark/tests.py                                 |  6 +++---
 python/pyspark/worker.py                                | 12 +++++++-----
 .../scala/org/apache/spark/sql/UDFRegistration.scala    |  2 ++
 .../org/apache/spark/sql/UserDefinedFunction.scala      |  5 +++--
 .../org/apache/spark/sql/execution/pythonUdfs.scala     |  2 ++
 10 files changed, 26 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/32fbd297/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 7409dc2..2d92f6a 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -47,6 +47,7 @@ private[spark] class PythonRDD(
     pythonIncludes: JList[String],
     preservePartitoning: Boolean,
     pythonExec: String,
+    pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]])
   extends RDD[Array[Byte]](parent) {
@@ -210,6 +211,8 @@ private[spark] class PythonRDD(
         val dataOut = new DataOutputStream(stream)
         // Partition index
         dataOut.writeInt(split.index)
+        // Python version of driver
+        PythonRDD.writeUTF(pythonVer, dataOut)
         // sparkFilesDir
         PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
         // Python includes (*.zip and *.egg files)

http://git-wip-us.apache.org/repos/asf/spark/blob/32fbd297/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 3199279..d25ee85 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -173,6 +173,7 @@ class SparkContext(object):
             self._jvm.PythonAccumulatorParam(host, port))
 
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
+        self.pythonVer = "%d.%d" % sys.version_info[:2]
 
         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have

http://git-wip-us.apache.org/repos/asf/spark/blob/32fbd297/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 545c5ad..70db4bb 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2260,7 +2260,7 @@ class RDD(object):
 def _prepare_for_python_RDD(sc, command, obj=None):
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
-    pickled_command = ser.dumps((command, sys.version_info[:2]))
+    pickled_command = ser.dumps(command)
     if len(pickled_command) > (1 << 20):  # 1M
         # The broadcast will have same life cycle as created PythonRDD
         broadcast = sc.broadcast(pickled_command)
@@ -2344,7 +2344,7 @@ class PipelinedRDD(RDD):
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
                                              bytearray(pickled_cmd),
                                              env, includes, self.preservesPartitioning,
-                                             self.ctx.pythonExec,
+                                             self.ctx.pythonExec, self.ctx.pythonVer,
                                              bvars, self.ctx._javaAccumulator)
         self._jrdd_val = python_rdd.asJavaRDD()
 

http://git-wip-us.apache.org/repos/asf/spark/blob/32fbd297/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index f6f107c..0bde719 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -157,6 +157,7 @@ class SQLContext(object):
                                             env,
                                             includes,
                                             self._sc.pythonExec,
+                                            self._sc.pythonVer,
                                             bvars,
                                             self._sc._javaAccumulator,
                                             returnType.json())

http://git-wip-us.apache.org/repos/asf/spark/blob/32fbd297/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8d0e766..fbe9bf5 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -353,8 +353,8 @@ class UserDefinedFunction(object):
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
         fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
-        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env,
-                                                 includes, sc.pythonExec, broadcast_vars,
+        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
+                                                 sc.pythonExec, sc.pythonVer, broadcast_vars,
                                                  sc._javaAccumulator, jdt)
         return judf
 

http://git-wip-us.apache.org/repos/asf/spark/blob/32fbd297/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 09de4d1..5e023f6 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1543,13 +1543,13 @@ class WorkerTests(ReusedPySparkTestCase):
     def test_with_different_versions_of_python(self):
         rdd = self.sc.parallelize(range(10))
         rdd.count()
-        version = sys.version_info
-        sys.version_info = (2, 0, 0)
+        version = self.sc.pythonVer
+        self.sc.pythonVer = "2.0"
         try:
             with QuietTest(self.sc):
                 self.assertRaises(Py4JJavaError, lambda: rdd.count())
         finally:
-            sys.version_info = version
+            self.sc.pythonVer = version
 
 
 class SparkSubmitTests(unittest.TestCase):

http://git-wip-us.apache.org/repos/asf/spark/blob/32fbd297/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index fbdaf3a..93df900 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -57,6 +57,12 @@ def main(infile, outfile):
         if split_index == -1:  # for unit tests
             exit(-1)
 
+        version = utf8_deserializer.loads(infile)
+        if version != "%d.%d" % sys.version_info[:2]:
+            raise Exception(("Python in worker has different version %s than that in " +
+                             "driver %s, PySpark cannot run with different minor versions") %
+                            ("%d.%d" % sys.version_info[:2], version))
+
         # initialize global state
         shuffle.MemoryBytesSpilled = 0
         shuffle.DiskBytesSpilled = 0
@@ -92,11 +98,7 @@ def main(infile, outfile):
         command = pickleSer._read_with_length(infile)
         if isinstance(command, Broadcast):
             command = pickleSer.loads(command.value)
-        (func, profiler, deserializer, serializer), version = command
-        if version != sys.version_info[:2]:
-            raise Exception(("Python in worker has different version %s than that in " +
-                            "driver %s, PySpark cannot run with different minor versions") %
-                            (sys.version_info[:2], version))
+        func, profiler, deserializer, serializer = command
         init_time = time.time()
 
         def process():

http://git-wip-us.apache.org/repos/asf/spark/blob/32fbd297/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index dc3389c..3cc5c24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -46,6 +46,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
       envVars: JMap[String, String],
       pythonIncludes: JList[String],
       pythonExec: String,
+      pythonVer: String,
       broadcastVars: JList[Broadcast[PythonBroadcast]],
       accumulator: Accumulator[JList[Array[Byte]]],
       stringDataType: String): Unit = {
@@ -70,6 +71,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
         envVars,
         pythonIncludes,
         pythonExec,
+        pythonVer,
         broadcastVars,
         accumulator,
         dataType,

http://git-wip-us.apache.org/repos/asf/spark/blob/32fbd297/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
index 505ab13..a02e202 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
@@ -58,14 +58,15 @@ private[sql] case class UserDefinedPythonFunction(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
+    pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType) {
 
   /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
   def apply(exprs: Column*): Column = {
-    val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars,
-      accumulator, dataType, exprs.map(_.expr))
+    val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer,
+      broadcastVars, accumulator, dataType, exprs.map(_.expr))
     Column(udf)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/32fbd297/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 65dd7ba..11b2897 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -46,6 +46,7 @@ private[spark] case class PythonUDF(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
+    pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType,
@@ -251,6 +252,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       udf.pythonIncludes,
       false,
       udf.pythonExec,
+      udf.pythonVer,
       udf.broadcastVars,
       udf.accumulator
     ).mapPartitions { iter =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org