You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2013/11/27 05:55:59 UTC

[2/7] git commit: Remove Pickle-wrapping of Java objects in PySpark.

Remove Pickle-wrapping of Java objects in PySpark.

If we support custom serializers, the Python
worker will know what type of input to expect,
so we won't need to wrap Tuple2 and Strings into
pickled tuples and strings.

Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/7d68a81a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/7d68a81a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/7d68a81a

Branch: refs/heads/master
Commit: 7d68a81a8ed5f49fefb3bd0fa0b9d3835cc7d86e
Parents: a48d88d
Author: Josh Rosen <jo...@apache.org>
Authored: Sun Nov 3 11:03:02 2013 -0800
Committer: Josh Rosen <jo...@apache.org>
Committed: Sun Nov 3 11:03:02 2013 -0800

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 106 +++++++------------
 python/pyspark/context.py                       |  10 +-
 python/pyspark/rdd.py                           |  11 +-
 python/pyspark/serializers.py                   |  18 ++++
 python/pyspark/worker.py                        |  14 ++-
 5 files changed, 78 insertions(+), 81 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7d68a81a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0d5913e..eb0b0db 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -75,7 +75,7 @@ private[spark] class PythonRDD[T: ClassManifest](
           // Partition index
           dataOut.writeInt(split.index)
           // sparkFilesDir
-          PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+          dataOut.writeUTF(SparkFiles.getRootDirectory)
           // Broadcast variables
           dataOut.writeInt(broadcastVars.length)
           for (broadcast <- broadcastVars) {
@@ -85,9 +85,7 @@ private[spark] class PythonRDD[T: ClassManifest](
           }
           // Python includes (*.zip and *.egg files)
           dataOut.writeInt(pythonIncludes.length)
-          for (f <- pythonIncludes) {
-            PythonRDD.writeAsPickle(f, dataOut)
-          }
+          pythonIncludes.foreach(dataOut.writeUTF)
           dataOut.flush()
           // Serialized user code
           for (elem <- command) {
@@ -96,7 +94,7 @@ private[spark] class PythonRDD[T: ClassManifest](
           printOut.flush()
           // Data values
           for (elem <- parent.iterator(split, context)) {
-            PythonRDD.writeAsPickle(elem, dataOut)
+            PythonRDD.writeToStream(elem, dataOut)
           }
           dataOut.flush()
           printOut.flush()
@@ -205,60 +203,7 @@ private object SpecialLengths {
 
 private[spark] object PythonRDD {
 
-  /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
-  def stripPickle(arr: Array[Byte]) : Array[Byte] = {
-    arr.slice(2, arr.length - 1)
-  }
-
-  /**
-   * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
-   * The data format is a 32-bit integer representing the pickled object's length (in bytes),
-   * followed by the pickled data.
-   *
-   * Pickle module:
-   *
-   *    http://docs.python.org/2/library/pickle.html
-   *
-   * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
-   *
-   *    http://hg.python.org/cpython/file/2.6/Lib/pickle.py
-   *    http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
-   *
-   * @param elem the object to write
-   * @param dOut a data output stream
-   */
-  def writeAsPickle(elem: Any, dOut: DataOutputStream) {
-    if (elem.isInstanceOf[Array[Byte]]) {
-      val arr = elem.asInstanceOf[Array[Byte]]
-      dOut.writeInt(arr.length)
-      dOut.write(arr)
-    } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
-      val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
-      val length = t._1.length + t._2.length - 3 - 3 + 4  // stripPickle() removes 3 bytes
-      dOut.writeInt(length)
-      dOut.writeByte(Pickle.PROTO)
-      dOut.writeByte(Pickle.TWO)
-      dOut.write(PythonRDD.stripPickle(t._1))
-      dOut.write(PythonRDD.stripPickle(t._2))
-      dOut.writeByte(Pickle.TUPLE2)
-      dOut.writeByte(Pickle.STOP)
-    } else if (elem.isInstanceOf[String]) {
-      // For uniformity, strings are wrapped into Pickles.
-      val s = elem.asInstanceOf[String].getBytes("UTF-8")
-      val length = 2 + 1 + 4 + s.length + 1
-      dOut.writeInt(length)
-      dOut.writeByte(Pickle.PROTO)
-      dOut.writeByte(Pickle.TWO)
-      dOut.write(Pickle.BINUNICODE)
-      dOut.writeInt(Integer.reverseBytes(s.length))
-      dOut.write(s)
-      dOut.writeByte(Pickle.STOP)
-    } else {
-      throw new SparkException("Unexpected RDD type")
-    }
-  }
-
-  def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+  def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
     val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
@@ -276,15 +221,46 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+  def writeStringAsPickle(elem: String, dOut: DataOutputStream) {
+    val s = elem.getBytes("UTF-8")
+    val length = 2 + 1 + 4 + s.length + 1
+    dOut.writeInt(length)
+    dOut.writeByte(Pickle.PROTO)
+    dOut.writeByte(Pickle.TWO)
+    dOut.write(Pickle.BINUNICODE)
+    dOut.writeInt(Integer.reverseBytes(s.length))
+    dOut.write(s)
+    dOut.writeByte(Pickle.STOP)
+  }
+
+  def writeToStream(elem: Any, dataOut: DataOutputStream) {
+    elem match {
+      case bytes: Array[Byte] =>
+        dataOut.writeInt(bytes.length)
+        dataOut.write(bytes)
+      case pair: (Array[Byte], Array[Byte]) =>
+        dataOut.writeInt(pair._1.length)
+        dataOut.write(pair._1)
+        dataOut.writeInt(pair._2.length)
+        dataOut.write(pair._2)
+      case str: String =>
+        // Until we've implemented full custom serializer support, we need to return
+        // strings as Pickles to properly support union() and cartesian():
+        writeStringAsPickle(str, dataOut)
+      case other =>
+        throw new SparkException("Unexpected element type " + other.getClass)
+    }
+  }
+
+  def writeToFile[T](items: java.util.Iterator[T], filename: String) {
     import scala.collection.JavaConverters._
-    writeIteratorToPickleFile(items.asScala, filename)
+    writeToFile(items.asScala, filename)
   }
 
-  def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
+  def writeToFile[T](items: Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
     for (item <- items) {
-      writeAsPickle(item, file)
+      writeToStream(item, file)
     }
     file.close()
   }
@@ -300,10 +276,6 @@ private object Pickle {
   val TWO: Byte = 0x02.toByte
   val BINUNICODE: Byte = 'X'
   val STOP: Byte = '.'
-  val TUPLE2: Byte = 0x86.toByte
-  val EMPTY_LIST: Byte = ']'
-  val MARK: Byte = '('
-  val APPENDS: Byte = 'e'
 }
 
 private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7d68a81a/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a7ca8bc..0fec1a6 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -42,7 +42,7 @@ class SparkContext(object):
 
     _gateway = None
     _jvm = None
-    _writeIteratorToPickleFile = None
+    _writeToFile = None
     _takePartition = None
     _next_accum_id = 0
     _active_spark_context = None
@@ -125,8 +125,8 @@ class SparkContext(object):
             if not SparkContext._gateway:
                 SparkContext._gateway = launch_gateway()
                 SparkContext._jvm = SparkContext._gateway.jvm
-                SparkContext._writeIteratorToPickleFile = \
-                    SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+                SparkContext._writeToFile = \
+                    SparkContext._jvm.PythonRDD.writeToFile
                 SparkContext._takePartition = \
                     SparkContext._jvm.PythonRDD.takePartition
 
@@ -190,8 +190,8 @@ class SparkContext(object):
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
-        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+        readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+        jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)
 
     def textFile(self, name, minSplits=None):

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7d68a81a/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7019fb8..d3c4d13 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -54,6 +54,7 @@ class RDD(object):
         self.is_checkpointed = False
         self.ctx = ctx
         self._partitionFunc = None
+        self._stage_input_is_pairs = False
 
     @property
     def context(self):
@@ -344,6 +345,7 @@ class RDD(object):
                     yield pair
             else:
                 yield pair
+        java_cartesian._stage_input_is_pairs = True
         return java_cartesian.flatMap(unpack_batches)
 
     def groupBy(self, f, numPartitions=None):
@@ -391,8 +393,8 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        picklesInJava = self._jrdd.collect().iterator()
-        return list(self._collect_iterator_through_file(picklesInJava))
+        bytesInJava = self._jrdd.collect().iterator()
+        return list(self._collect_iterator_through_file(bytesInJava))
 
     def _collect_iterator_through_file(self, iterator):
         # Transferring lots of data through Py4J can be slow because
@@ -400,7 +402,7 @@ class RDD(object):
         # file and read it back.
         tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
         tempFile.close()
-        self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
+        self.ctx._writeToFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
             for item in read_from_pickle_file(tempFile):
@@ -941,6 +943,7 @@ class PipelinedRDD(RDD):
             self.func = func
             self.preservesPartitioning = preservesPartitioning
             self._prev_jrdd = prev._jrdd
+        self._stage_input_is_pairs = prev._stage_input_is_pairs
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = prev.ctx
@@ -959,7 +962,7 @@ class PipelinedRDD(RDD):
             def batched_func(split, iterator):
                 return batched(oldfunc(split, iterator), batchSize)
             func = batched_func
-        cmds = [func, self._bypass_serializer]
+        cmds = [func, self._bypass_serializer, self._stage_input_is_pairs]
         pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7d68a81a/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index fbc280f..fd02e1e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -93,6 +93,14 @@ def write_with_length(obj, stream):
     stream.write(obj)
 
 
+def read_mutf8(stream):
+    """
+    Read a string written with Java's DataOutputStream.writeUTF() method.
+    """
+    length = struct.unpack('>H', stream.read(2))[0]
+    return stream.read(length).decode('utf8')
+
+
 def read_with_length(stream):
     length = read_int(stream)
     obj = stream.read(length)
@@ -112,3 +120,13 @@ def read_from_pickle_file(stream):
                 yield obj
     except EOFError:
         return
+
+
+def read_pairs_from_pickle_file(stream):
+    try:
+        while True:
+            a = load_pickle(read_with_length(stream))
+            b = load_pickle(read_with_length(stream))
+            yield (a, b)
+    except EOFError:
+        return
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7d68a81a/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 7696df9..4e64557 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,8 +31,8 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, read_with_length, write_int, \
-    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file \
-    SpecialLengths
+    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \
+    SpecialLengths, read_mutf8, read_pairs_from_pickle_file
 
 
 def load_obj(infile):
@@ -53,7 +53,7 @@ def main(infile, outfile):
         return
 
     # fetch name of workdir
-    spark_files_dir = load_pickle(read_with_length(infile))
+    spark_files_dir = read_mutf8(infile)
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
 
@@ -68,17 +68,21 @@ def main(infile, outfile):
     sys.path.append(spark_files_dir) # *.py files that were added will be copied here
     num_python_includes =  read_int(infile)
     for _ in range(num_python_includes):
-        sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+        sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile)))
 
     # now load function
     func = load_obj(infile)
     bypassSerializer = load_obj(infile)
+    stageInputIsPairs = load_obj(infile)
     if bypassSerializer:
         dumps = lambda x: x
     else:
         dumps = dump_pickle
     init_time = time.time()
-    iterator = read_from_pickle_file(infile)
+    if stageInputIsPairs:
+        iterator = read_pairs_from_pickle_file(infile)
+    else:
+        iterator = read_from_pickle_file(infile)
     try:
         for obj in func(split_index, iterator):
             write_with_length(dumps(obj), outfile)