You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2013/11/27 05:55:58 UTC

[1/7] git commit: Replace magic lengths with constants in PySpark.

Updated Branches:
  refs/heads/master 330ada176 -> fb6875dd5


Replace magic lengths with constants in PySpark.

Write the number of accumulator updates up front rather
than terminating the accumulators section with a negative
length. I find this easier to read.


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/a48d88d2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/a48d88d2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/a48d88d2

Branch: refs/heads/master
Commit: a48d88d206fae348720ab077a624b3c57293374f
Parents: 41ead7a
Author: Josh Rosen <jo...@apache.org>
Authored: Sat Nov 2 21:13:18 2013 -0700
Committer: Josh Rosen <jo...@apache.org>
Committed: Sun Nov 3 10:54:24 2013 -0800

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 26 ++++++++++++--------
 python/pyspark/serializers.py                   |  6 +++++
 python/pyspark/worker.py                        | 13 +++++-----
 3 files changed, 29 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/a48d88d2/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 12b4d94..0d5913e 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -132,7 +132,7 @@ private[spark] class PythonRDD[T: ClassManifest](
               val obj = new Array[Byte](length)
               stream.readFully(obj)
               obj
-            case -3 =>
+            case SpecialLengths.TIMING_DATA =>
               // Timing data from worker
               val bootTime = stream.readLong()
               val initTime = stream.readLong()
@@ -143,24 +143,24 @@ private[spark] class PythonRDD[T: ClassManifest](
               val total = finishTime - startTime
               logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
               read
-            case -2 =>
+            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
               // Signals that an exception has been thrown in python
               val exLength = stream.readInt()
               val obj = new Array[Byte](exLength)
               stream.readFully(obj)
               throw new PythonException(new String(obj))
-            case -1 =>
+            case SpecialLengths.END_OF_DATA_SECTION =>
               // We've finished the data section of the output, but we can still
-              // read some accumulator updates; let's do that, breaking when we
-              // get a negative length record.
-              var len2 = stream.readInt()
-              while (len2 >= 0) {
-                val update = new Array[Byte](len2)
+              // read some accumulator updates:
+              val numAccumulatorUpdates = stream.readInt()
+              (1 to numAccumulatorUpdates).foreach { _ =>
+                val updateLen = stream.readInt()
+                val update = new Array[Byte](updateLen)
                 stream.readFully(update)
                 accumulator += Collections.singletonList(update)
-                len2 = stream.readInt()
+
               }
-              new Array[Byte](0)
+              Array.empty[Byte]
           }
         } catch {
           case eof: EOFException => {
@@ -197,6 +197,12 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
   val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
 }
 
+private object SpecialLengths {
+  val END_OF_DATA_SECTION = -1
+  val PYTHON_EXCEPTION_THROWN = -2
+  val TIMING_DATA = -3
+}
+
 private[spark] object PythonRDD {
 
   /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/a48d88d2/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 54fed1c..fbc280f 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -19,6 +19,12 @@ import struct
 import cPickle
 
 
+class SpecialLengths(object):
+    END_OF_DATA_SECTION = -1
+    PYTHON_EXCEPTION_THROWN = -2
+    TIMING_DATA = -3
+
+
 class Batch(object):
     """
     Used to store multiple RDD entries as a single Java object.

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/a48d88d2/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d63c2aa..7696df9 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,7 +31,8 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, read_with_length, write_int, \
-    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file \
+    SpecialLengths
 
 
 def load_obj(infile):
@@ -39,7 +40,7 @@ def load_obj(infile):
 
 
 def report_times(outfile, boot, init, finish):
-    write_int(-3, outfile)
+    write_int(SpecialLengths.TIMING_DATA, outfile)
     write_long(1000 * boot, outfile)
     write_long(1000 * init, outfile)
     write_long(1000 * finish, outfile)
@@ -82,16 +83,16 @@ def main(infile, outfile):
         for obj in func(split_index, iterator):
             write_with_length(dumps(obj), outfile)
     except Exception as e:
-        write_int(-2, outfile)
+        write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
         write_with_length(traceback.format_exc(), outfile)
         sys.exit(-1)
     finish_time = time.time()
     report_times(outfile, boot_time, init_time, finish_time)
     # Mark the beginning of the accumulators section of the output
-    write_int(-1, outfile)
-    for aid, accum in _accumulatorRegistry.items():
+    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+    write_int(len(_accumulatorRegistry), outfile)
+    for (aid, accum) in _accumulatorRegistry.items():
         write_with_length(dump_pickle((aid, accum._value)), outfile)
-    write_int(-1, outfile)
 
 
 if __name__ == '__main__':


[2/7] git commit: Remove Pickle-wrapping of Java objects in PySpark.

Posted by ma...@apache.org.
Remove Pickle-wrapping of Java objects in PySpark.

If we support custom serializers, the Python
worker will know what type of input to expect,
so we won't need to wrap Tuple2s and Strings into
pickled tuples and strings.

Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/7d68a81a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/7d68a81a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/7d68a81a

Branch: refs/heads/master
Commit: 7d68a81a8ed5f49fefb3bd0fa0b9d3835cc7d86e
Parents: a48d88d
Author: Josh Rosen <jo...@apache.org>
Authored: Sun Nov 3 11:03:02 2013 -0800
Committer: Josh Rosen <jo...@apache.org>
Committed: Sun Nov 3 11:03:02 2013 -0800

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 106 +++++++------------
 python/pyspark/context.py                       |  10 +-
 python/pyspark/rdd.py                           |  11 +-
 python/pyspark/serializers.py                   |  18 ++++
 python/pyspark/worker.py                        |  14 ++-
 5 files changed, 78 insertions(+), 81 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7d68a81a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0d5913e..eb0b0db 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -75,7 +75,7 @@ private[spark] class PythonRDD[T: ClassManifest](
           // Partition index
           dataOut.writeInt(split.index)
           // sparkFilesDir
-          PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+          dataOut.writeUTF(SparkFiles.getRootDirectory)
           // Broadcast variables
           dataOut.writeInt(broadcastVars.length)
           for (broadcast <- broadcastVars) {
@@ -85,9 +85,7 @@ private[spark] class PythonRDD[T: ClassManifest](
           }
           // Python includes (*.zip and *.egg files)
           dataOut.writeInt(pythonIncludes.length)
-          for (f <- pythonIncludes) {
-            PythonRDD.writeAsPickle(f, dataOut)
-          }
+          pythonIncludes.foreach(dataOut.writeUTF)
           dataOut.flush()
           // Serialized user code
           for (elem <- command) {
@@ -96,7 +94,7 @@ private[spark] class PythonRDD[T: ClassManifest](
           printOut.flush()
           // Data values
           for (elem <- parent.iterator(split, context)) {
-            PythonRDD.writeAsPickle(elem, dataOut)
+            PythonRDD.writeToStream(elem, dataOut)
           }
           dataOut.flush()
           printOut.flush()
@@ -205,60 +203,7 @@ private object SpecialLengths {
 
 private[spark] object PythonRDD {
 
-  /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
-  def stripPickle(arr: Array[Byte]) : Array[Byte] = {
-    arr.slice(2, arr.length - 1)
-  }
-
-  /**
-   * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
-   * The data format is a 32-bit integer representing the pickled object's length (in bytes),
-   * followed by the pickled data.
-   *
-   * Pickle module:
-   *
-   *    http://docs.python.org/2/library/pickle.html
-   *
-   * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
-   *
-   *    http://hg.python.org/cpython/file/2.6/Lib/pickle.py
-   *    http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
-   *
-   * @param elem the object to write
-   * @param dOut a data output stream
-   */
-  def writeAsPickle(elem: Any, dOut: DataOutputStream) {
-    if (elem.isInstanceOf[Array[Byte]]) {
-      val arr = elem.asInstanceOf[Array[Byte]]
-      dOut.writeInt(arr.length)
-      dOut.write(arr)
-    } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
-      val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
-      val length = t._1.length + t._2.length - 3 - 3 + 4  // stripPickle() removes 3 bytes
-      dOut.writeInt(length)
-      dOut.writeByte(Pickle.PROTO)
-      dOut.writeByte(Pickle.TWO)
-      dOut.write(PythonRDD.stripPickle(t._1))
-      dOut.write(PythonRDD.stripPickle(t._2))
-      dOut.writeByte(Pickle.TUPLE2)
-      dOut.writeByte(Pickle.STOP)
-    } else if (elem.isInstanceOf[String]) {
-      // For uniformity, strings are wrapped into Pickles.
-      val s = elem.asInstanceOf[String].getBytes("UTF-8")
-      val length = 2 + 1 + 4 + s.length + 1
-      dOut.writeInt(length)
-      dOut.writeByte(Pickle.PROTO)
-      dOut.writeByte(Pickle.TWO)
-      dOut.write(Pickle.BINUNICODE)
-      dOut.writeInt(Integer.reverseBytes(s.length))
-      dOut.write(s)
-      dOut.writeByte(Pickle.STOP)
-    } else {
-      throw new SparkException("Unexpected RDD type")
-    }
-  }
-
-  def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+  def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
     val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
@@ -276,15 +221,46 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+  def writeStringAsPickle(elem: String, dOut: DataOutputStream) {
+    val s = elem.getBytes("UTF-8")
+    val length = 2 + 1 + 4 + s.length + 1
+    dOut.writeInt(length)
+    dOut.writeByte(Pickle.PROTO)
+    dOut.writeByte(Pickle.TWO)
+    dOut.write(Pickle.BINUNICODE)
+    dOut.writeInt(Integer.reverseBytes(s.length))
+    dOut.write(s)
+    dOut.writeByte(Pickle.STOP)
+  }
+
+  def writeToStream(elem: Any, dataOut: DataOutputStream) {
+    elem match {
+      case bytes: Array[Byte] =>
+        dataOut.writeInt(bytes.length)
+        dataOut.write(bytes)
+      case pair: (Array[Byte], Array[Byte]) =>
+        dataOut.writeInt(pair._1.length)
+        dataOut.write(pair._1)
+        dataOut.writeInt(pair._2.length)
+        dataOut.write(pair._2)
+      case str: String =>
+        // Until we've implemented full custom serializer support, we need to return
+        // strings as Pickles to properly support union() and cartesian():
+        writeStringAsPickle(str, dataOut)
+      case other =>
+        throw new SparkException("Unexpected element type " + other.getClass)
+    }
+  }
+
+  def writeToFile[T](items: java.util.Iterator[T], filename: String) {
     import scala.collection.JavaConverters._
-    writeIteratorToPickleFile(items.asScala, filename)
+    writeToFile(items.asScala, filename)
   }
 
-  def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
+  def writeToFile[T](items: Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
     for (item <- items) {
-      writeAsPickle(item, file)
+      writeToStream(item, file)
     }
     file.close()
   }
@@ -300,10 +276,6 @@ private object Pickle {
   val TWO: Byte = 0x02.toByte
   val BINUNICODE: Byte = 'X'
   val STOP: Byte = '.'
-  val TUPLE2: Byte = 0x86.toByte
-  val EMPTY_LIST: Byte = ']'
-  val MARK: Byte = '('
-  val APPENDS: Byte = 'e'
 }
 
 private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7d68a81a/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a7ca8bc..0fec1a6 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -42,7 +42,7 @@ class SparkContext(object):
 
     _gateway = None
     _jvm = None
-    _writeIteratorToPickleFile = None
+    _writeToFile = None
     _takePartition = None
     _next_accum_id = 0
     _active_spark_context = None
@@ -125,8 +125,8 @@ class SparkContext(object):
             if not SparkContext._gateway:
                 SparkContext._gateway = launch_gateway()
                 SparkContext._jvm = SparkContext._gateway.jvm
-                SparkContext._writeIteratorToPickleFile = \
-                    SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+                SparkContext._writeToFile = \
+                    SparkContext._jvm.PythonRDD.writeToFile
                 SparkContext._takePartition = \
                     SparkContext._jvm.PythonRDD.takePartition
 
@@ -190,8 +190,8 @@ class SparkContext(object):
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
-        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+        readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+        jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)
 
     def textFile(self, name, minSplits=None):

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7d68a81a/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7019fb8..d3c4d13 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -54,6 +54,7 @@ class RDD(object):
         self.is_checkpointed = False
         self.ctx = ctx
         self._partitionFunc = None
+        self._stage_input_is_pairs = False
 
     @property
     def context(self):
@@ -344,6 +345,7 @@ class RDD(object):
                     yield pair
             else:
                 yield pair
+        java_cartesian._stage_input_is_pairs = True
         return java_cartesian.flatMap(unpack_batches)
 
     def groupBy(self, f, numPartitions=None):
@@ -391,8 +393,8 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        picklesInJava = self._jrdd.collect().iterator()
-        return list(self._collect_iterator_through_file(picklesInJava))
+        bytesInJava = self._jrdd.collect().iterator()
+        return list(self._collect_iterator_through_file(bytesInJava))
 
     def _collect_iterator_through_file(self, iterator):
         # Transferring lots of data through Py4J can be slow because
@@ -400,7 +402,7 @@ class RDD(object):
         # file and read it back.
         tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
         tempFile.close()
-        self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
+        self.ctx._writeToFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
             for item in read_from_pickle_file(tempFile):
@@ -941,6 +943,7 @@ class PipelinedRDD(RDD):
             self.func = func
             self.preservesPartitioning = preservesPartitioning
             self._prev_jrdd = prev._jrdd
+        self._stage_input_is_pairs = prev._stage_input_is_pairs
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = prev.ctx
@@ -959,7 +962,7 @@ class PipelinedRDD(RDD):
             def batched_func(split, iterator):
                 return batched(oldfunc(split, iterator), batchSize)
             func = batched_func
-        cmds = [func, self._bypass_serializer]
+        cmds = [func, self._bypass_serializer, self._stage_input_is_pairs]
         pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7d68a81a/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index fbc280f..fd02e1e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -93,6 +93,14 @@ def write_with_length(obj, stream):
     stream.write(obj)
 
 
+def read_mutf8(stream):
+    """
+    Read a string written with Java's DataOutputStream.writeUTF() method.
+    """
+    length = struct.unpack('>H', stream.read(2))[0]
+    return stream.read(length).decode('utf8')
+
+
 def read_with_length(stream):
     length = read_int(stream)
     obj = stream.read(length)
@@ -112,3 +120,13 @@ def read_from_pickle_file(stream):
                 yield obj
     except EOFError:
         return
+
+
+def read_pairs_from_pickle_file(stream):
+    try:
+        while True:
+            a = load_pickle(read_with_length(stream))
+            b = load_pickle(read_with_length(stream))
+            yield (a, b)
+    except EOFError:
+        return
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/7d68a81a/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 7696df9..4e64557 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,8 +31,8 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, read_with_length, write_int, \
-    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file \
-    SpecialLengths
+    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \
+    SpecialLengths, read_mutf8, read_pairs_from_pickle_file
 
 
 def load_obj(infile):
@@ -53,7 +53,7 @@ def main(infile, outfile):
         return
 
     # fetch name of workdir
-    spark_files_dir = load_pickle(read_with_length(infile))
+    spark_files_dir = read_mutf8(infile)
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
 
@@ -68,17 +68,21 @@ def main(infile, outfile):
     sys.path.append(spark_files_dir) # *.py files that were added will be copied here
     num_python_includes =  read_int(infile)
     for _ in range(num_python_includes):
-        sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+        sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile)))
 
     # now load function
     func = load_obj(infile)
     bypassSerializer = load_obj(infile)
+    stageInputIsPairs = load_obj(infile)
     if bypassSerializer:
         dumps = lambda x: x
     else:
         dumps = dump_pickle
     init_time = time.time()
-    iterator = read_from_pickle_file(infile)
+    if stageInputIsPairs:
+        iterator = read_pairs_from_pickle_file(infile)
+    else:
+        iterator = read_from_pickle_file(infile)
     try:
         for obj in func(split_index, iterator):
             write_with_length(dumps(obj), outfile)


[5/7] git commit: FramedSerializer: _dumps => dumps, _loads => loads.

Posted by ma...@apache.org.
FramedSerializer: _dumps => dumps, _loads => loads.


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/13122ceb
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/13122ceb
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/13122ceb

Branch: refs/heads/master
Commit: 13122ceb8c74dc0c4ad37902a3d1b30bf273cc6a
Parents: ffa5bed
Author: Josh Rosen <jo...@apache.org>
Authored: Sun Nov 10 17:48:27 2013 -0800
Committer: Josh Rosen <jo...@apache.org>
Committed: Sun Nov 10 17:53:25 2013 -0800

----------------------------------------------------------------------
 python/pyspark/context.py     |  2 +-
 python/pyspark/rdd.py         |  4 ++--
 python/pyspark/serializers.py | 26 +++++++++++++-------------
 python/pyspark/worker.py      |  4 ++--
 4 files changed, 18 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/13122ceb/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6bb1c6c..cbd41e5 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -251,7 +251,7 @@ class SparkContext(object):
         sent to each cluster only once.
         """
         pickleSer = PickleSerializer()
-        pickled = pickleSer._dumps(value)
+        pickled = pickleSer.dumps(value)
         jbroadcast = self._jsc.broadcast(bytearray(pickled))
         return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/13122ceb/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 062f44f..957f3f8 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -751,7 +751,7 @@ class RDD(object):
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
             for (split, items) in buckets.iteritems():
                 yield pack_long(split)
-                yield outputSerializer._dumps(items)
+                yield outputSerializer.dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
@@ -970,7 +970,7 @@ class PipelinedRDD(RDD):
         else:
             serializer = self.ctx.serializer
         command = (self.func, self._prev_jrdd_deserializer, serializer)
-        pickled_command = CloudPickleSerializer()._dumps(command)
+        pickled_command = CloudPickleSerializer().dumps(command)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/13122ceb/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index b23804b..9338df6 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -125,7 +125,7 @@ class FramedSerializer(Serializer):
                 return
 
     def _write_with_length(self, obj, stream):
-        serialized = self._dumps(obj)
+        serialized = self.dumps(obj)
         write_int(len(serialized), stream)
         stream.write(serialized)
 
@@ -134,16 +134,16 @@ class FramedSerializer(Serializer):
         obj = stream.read(length)
         if obj == "":
             raise EOFError
-        return self._loads(obj)
+        return self.loads(obj)
 
-    def _dumps(self, obj):
+    def dumps(self, obj):
         """
         Serialize an object into a byte array.
         When batching is used, this will be called with an array of objects.
         """
         raise NotImplementedError
 
-    def _loads(self, obj):
+    def loads(self, obj):
         """
         Deserialize an object from a byte array.
         """
@@ -228,8 +228,8 @@ class CartesianDeserializer(FramedSerializer):
 
 class NoOpSerializer(FramedSerializer):
 
-    def _loads(self, obj): return obj
-    def _dumps(self, obj): return obj
+    def loads(self, obj): return obj
+    def dumps(self, obj): return obj
 
 
 class PickleSerializer(FramedSerializer):
@@ -242,12 +242,12 @@ class PickleSerializer(FramedSerializer):
     not be as fast as more specialized serializers.
     """
 
-    def _dumps(self, obj): return cPickle.dumps(obj, 2)
-    _loads = cPickle.loads
+    def dumps(self, obj): return cPickle.dumps(obj, 2)
+    loads = cPickle.loads
 
 class CloudPickleSerializer(PickleSerializer):
 
-    def _dumps(self, obj): return cloudpickle.dumps(obj, 2)
+    def dumps(self, obj): return cloudpickle.dumps(obj, 2)
 
 
 class MarshalSerializer(FramedSerializer):
@@ -259,8 +259,8 @@ class MarshalSerializer(FramedSerializer):
     This serializer is faster than PickleSerializer but supports fewer datatypes.
     """
 
-    _dumps = marshal.dumps
-    _loads = marshal.loads
+    dumps = marshal.dumps
+    loads = marshal.loads
 
 
 class MUTF8Deserializer(Serializer):
@@ -268,14 +268,14 @@ class MUTF8Deserializer(Serializer):
     Deserializes streams written by Java's DataOutputStream.writeUTF().
     """
 
-    def _loads(self, stream):
+    def loads(self, stream):
         length = struct.unpack('>H', stream.read(2))[0]
         return stream.read(length).decode('utf8')
 
     def load_stream(self, stream):
         while True:
             try:
-                yield self._loads(stream)
+                yield self.loads(stream)
             except struct.error:
                 return
             except EOFError:

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/13122ceb/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 2751f12..f2b3f3c 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -51,7 +51,7 @@ def main(infile, outfile):
         return
 
     # fetch name of workdir
-    spark_files_dir = mutf8_deserializer._loads(infile)
+    spark_files_dir = mutf8_deserializer.loads(infile)
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
 
@@ -66,7 +66,7 @@ def main(infile, outfile):
     sys.path.append(spark_files_dir) # *.py files that were added will be copied here
     num_python_includes =  read_int(infile)
     for _ in range(num_python_includes):
-        filename = mutf8_deserializer._loads(infile)
+        filename = mutf8_deserializer.loads(infile)
         sys.path.append(os.path.join(spark_files_dir, filename))
 
     command = pickleSer._read_with_length(infile)


[3/7] git commit: Add custom serializer support to PySpark.

Posted by ma...@apache.org.
Add custom serializer support to PySpark.

For now, this only adds MarshalSerializer, but it lays the groundwork
for supporting other custom serializers. Many of these mechanisms
can also be used to support deserialization of different data formats
sent by Java, such as data encoded by MsgPack.

This also fixes a bug in SparkContext.union().


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/cbb7f04a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/cbb7f04a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/cbb7f04a

Branch: refs/heads/master
Commit: cbb7f04aef2220ece93dea9f3fa98b5db5f270d6
Parents: 7d68a81
Author: Josh Rosen <jo...@apache.org>
Authored: Tue Nov 5 17:52:39 2013 -0800
Committer: Josh Rosen <jo...@apache.org>
Committed: Sun Nov 10 16:45:38 2013 -0800

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala |  23 +-
 python/epydoc.conf                              |   2 +-
 python/pyspark/accumulators.py                  |   6 +-
 python/pyspark/context.py                       |  61 +++-
 python/pyspark/rdd.py                           |  86 ++---
 python/pyspark/serializers.py                   | 310 +++++++++++++++----
 python/pyspark/tests.py                         |   3 +-
 python/pyspark/worker.py                        |  41 ++-
 python/run-tests                                |   1 +
 9 files changed, 363 insertions(+), 170 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index eb0b0db..ef9bf4d 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -221,18 +221,6 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeStringAsPickle(elem: String, dOut: DataOutputStream) {
-    val s = elem.getBytes("UTF-8")
-    val length = 2 + 1 + 4 + s.length + 1
-    dOut.writeInt(length)
-    dOut.writeByte(Pickle.PROTO)
-    dOut.writeByte(Pickle.TWO)
-    dOut.write(Pickle.BINUNICODE)
-    dOut.writeInt(Integer.reverseBytes(s.length))
-    dOut.write(s)
-    dOut.writeByte(Pickle.STOP)
-  }
-
   def writeToStream(elem: Any, dataOut: DataOutputStream) {
     elem match {
       case bytes: Array[Byte] =>
@@ -244,9 +232,7 @@ private[spark] object PythonRDD {
         dataOut.writeInt(pair._2.length)
         dataOut.write(pair._2)
       case str: String =>
-        // Until we've implemented full custom serializer support, we need to return
-        // strings as Pickles to properly support union() and cartesian():
-        writeStringAsPickle(str, dataOut)
+        dataOut.writeUTF(str)
       case other =>
         throw new SparkException("Unexpected element type " + other.getClass)
     }
@@ -271,13 +257,6 @@ private[spark] object PythonRDD {
   }
 }
 
-private object Pickle {
-  val PROTO: Byte = 0x80.toByte
-  val TWO: Byte = 0x02.toByte
-  val BINUNICODE: Byte = 'X'
-  val STOP: Byte = '.'
-}
-
 private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
   override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
 }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/epydoc.conf
----------------------------------------------------------------------
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 1d0d002..0b42e72 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -32,6 +32,6 @@ target: docs/
 
 private: no
 
-exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
+exclude: pyspark.cloudpickle pyspark.worker pyspark.join
          pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
          pyspark.rddsampler pyspark.daemon

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/pyspark/accumulators.py
----------------------------------------------------------------------
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index da3d966..2204e9c 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -90,9 +90,11 @@ import struct
 import SocketServer
 import threading
 from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import read_int, read_with_length, load_pickle
+from pyspark.serializers import read_int, PickleSerializer
 
 
+pickleSer = PickleSerializer()
+
 # Holds accumulators registered on the current machine, keyed by ID. This is then used to send
 # the local accumulator updates back to the driver program at the end of a task.
 _accumulatorRegistry = {}
@@ -211,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
         from pyspark.accumulators import _accumulatorRegistry
         num_updates = read_int(self.rfile)
         for _ in range(num_updates):
-            (aid, update) = load_pickle(read_with_length(self.rfile))
+            (aid, update) = pickleSer._read_with_length(self.rfile)
             _accumulatorRegistry[aid] += update
         # Write a byte in acknowledgement
         self.wfile.write(struct.pack("!b", 1))

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 0fec1a6..6bb1c6c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 
@@ -51,7 +51,7 @@ class SparkContext(object):
 
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
-        environment=None, batchSize=1024):
+        environment=None, batchSize=1024, serializer=PickleSerializer()):
         """
         Create a new SparkContext.
 
@@ -67,6 +67,7 @@ class SparkContext(object):
         @param batchSize: The number of Python objects represented as a single
                Java object.  Set 1 to disable batching or -1 to use an
                unlimited batch size.
+        @param serializer: The serializer for RDDs.
 
 
         >>> from pyspark.context import SparkContext
@@ -83,7 +84,13 @@ class SparkContext(object):
         self.jobName = jobName
         self.sparkHome = sparkHome or None # None becomes null in Py4J
         self.environment = environment or {}
-        self.batchSize = batchSize  # -1 represents a unlimited batch size
+        self._batchSize = batchSize  # -1 represents an unlimited batch size
+        self._unbatched_serializer = serializer
+        if batchSize == 1:
+            self.serializer = self._unbatched_serializer
+        else:
+            self.serializer = BatchedSerializer(self._unbatched_serializer,
+                                                batchSize)
 
         # Create the Java SparkContext through Py4J
         empty_string_array = self._gateway.new_array(self._jvm.String, 0)
@@ -184,15 +191,17 @@ class SparkContext(object):
         # Make sure we distribute data evenly if it's smaller than self.batchSize
         if "__len__" not in dir(c):
             c = list(c)    # Make it a list so we can compute its length
-        batchSize = min(len(c) // numSlices, self.batchSize)
+        batchSize = min(len(c) // numSlices, self._batchSize)
         if batchSize > 1:
-            c = batched(c, batchSize)
-        for x in c:
-            write_with_length(dump_pickle(x), tempFile)
+            serializer = BatchedSerializer(self._unbatched_serializer,
+                                           batchSize)
+        else:
+            serializer = self._unbatched_serializer
+        serializer.dump_stream(c, tempFile)
         tempFile.close()
         readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
         jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
-        return RDD(jrdd, self)
+        return RDD(jrdd, self, serializer)
 
     def textFile(self, name, minSplits=None):
         """
@@ -201,21 +210,39 @@ class SparkContext(object):
         RDD of Strings.
         """
         minSplits = minSplits or min(self.defaultParallelism, 2)
-        jrdd = self._jsc.textFile(name, minSplits)
-        return RDD(jrdd, self)
+        return RDD(self._jsc.textFile(name, minSplits), self,
+                   MUTF8Deserializer())
 
-    def _checkpointFile(self, name):
+    def _checkpointFile(self, name, input_deserializer):
         jrdd = self._jsc.checkpointFile(name)
-        return RDD(jrdd, self)
+        return RDD(jrdd, self, input_deserializer)
 
     def union(self, rdds):
         """
         Build the union of a list of RDDs.
+
+        This supports unions() of RDDs with different serialized formats,
+        although this forces them to be reserialized using the default
+        serializer:
+
+        >>> path = os.path.join(tempdir, "union-text.txt")
+        >>> with open(path, "w") as testFile:
+        ...    testFile.write("Hello")
+        >>> textFile = sc.textFile(path)
+        >>> textFile.collect()
+        [u'Hello']
+        >>> parallelized = sc.parallelize(["World!"])
+        >>> sorted(sc.union([textFile, parallelized]).collect())
+        [u'Hello', 'World!']
         """
+        first_jrdd_deserializer = rdds[0]._jrdd_deserializer
+        if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
+            rdds = [x._reserialize() for x in rdds]
         first = rdds[0]._jrdd
         rest = [x._jrdd for x in rdds[1:]]
-        rest = ListConverter().convert(rest, self.gateway._gateway_client)
-        return RDD(self._jsc.union(first, rest), self)
+        rest = ListConverter().convert(rest, self._gateway._gateway_client)
+        return RDD(self._jsc.union(first, rest), self,
+                   rdds[0]._jrdd_deserializer)
 
     def broadcast(self, value):
         """
@@ -223,7 +250,9 @@ class SparkContext(object):
         object for reading it in distributed functions. The variable will be
         sent to each cluster only once.
         """
-        jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
+        pickleSer = PickleSerializer()
+        pickled = pickleSer._dumps(value)
+        jbroadcast = self._jsc.broadcast(bytearray(pickled))
         return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
 
@@ -235,7 +264,7 @@ class SparkContext(object):
         and floating-point numbers if you do not provide one. For other types,
         a custom AccumulatorParam can be used.
         """
-        if accum_param == None:
+        if accum_param is None:
             if isinstance(value, int):
                 accum_param = accumulators.INT_ACCUMULATOR_PARAM
             elif isinstance(value, float):

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d3c4d13..6691c30 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -18,7 +18,7 @@
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
-from itertools import chain, ifilter, imap, product
+from itertools import chain, ifilter, imap
 import operator
 import os
 import sys
@@ -28,8 +28,8 @@ from tempfile import NamedTemporaryFile
 from threading import Thread
 
 from pyspark import cloudpickle
-from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
-    read_from_pickle_file, pack_long
+from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
+    BatchedSerializer, pack_long
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -48,13 +48,12 @@ class RDD(object):
     operated on in parallel.
     """
 
-    def __init__(self, jrdd, ctx):
+    def __init__(self, jrdd, ctx, jrdd_deserializer):
         self._jrdd = jrdd
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = ctx
-        self._partitionFunc = None
-        self._stage_input_is_pairs = False
+        self._jrdd_deserializer = jrdd_deserializer
 
     @property
     def context(self):
@@ -248,7 +247,23 @@ class RDD(object):
         >>> rdd.union(rdd).collect()
         [1, 1, 2, 3, 1, 1, 2, 3]
         """
-        return RDD(self._jrdd.union(other._jrdd), self.ctx)
+        if self._jrdd_deserializer == other._jrdd_deserializer:
+            rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
+                      self._jrdd_deserializer)
+            return rdd
+        else:
+            # These RDDs contain data in different serialized formats, so we
+            # must normalize them to the default serializer.
+            self_copy = self._reserialize()
+            other_copy = other._reserialize()
+            return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+                       self.ctx.serializer)
+
+    def _reserialize(self):
+        if self._jrdd_deserializer == self.ctx.serializer:
+            return self
+        else:
+            return self.map(lambda x: x, preservesPartitioning=True)
 
     def __add__(self, other):
         """
@@ -335,18 +350,9 @@ class RDD(object):
         [(1, 1), (1, 2), (2, 1), (2, 2)]
         """
         # Due to batching, we can't use the Java cartesian method.
-        java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
-        def unpack_batches(pair):
-            (x, y) = pair
-            if type(x) == Batch or type(y) == Batch:
-                xs = x.items if type(x) == Batch else [x]
-                ys = y.items if type(y) == Batch else [y]
-                for pair in product(xs, ys):
-                    yield pair
-            else:
-                yield pair
-        java_cartesian._stage_input_is_pairs = True
-        return java_cartesian.flatMap(unpack_batches)
+        deserializer = CartesianDeserializer(self._jrdd_deserializer,
+                                             other._jrdd_deserializer)
+        return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
 
     def groupBy(self, f, numPartitions=None):
         """
@@ -405,7 +411,7 @@ class RDD(object):
         self.ctx._writeToFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
-            for item in read_from_pickle_file(tempFile):
+            for item in self._jrdd_deserializer.load_stream(tempFile):
                 yield item
         os.unlink(tempFile.name)
 
@@ -573,7 +579,7 @@ class RDD(object):
         items = []
         for partition in range(mapped._jrdd.splits().size()):
             iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
-            items.extend(self._collect_iterator_through_file(iterator))
+            items.extend(mapped._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
         return items[:num]
@@ -737,6 +743,7 @@ class RDD(object):
         # Transferring O(n) objects to Java is too expensive.  Instead, we'll
         # form the hash buckets in Python, transferring O(numPartitions) objects
         # to Java.  Each object is a (splitNumber, [objects]) pair.
+        outputSerializer = self.ctx._unbatched_serializer
         def add_shuffle_key(split, iterator):
 
             buckets = defaultdict(list)
@@ -745,14 +752,14 @@ class RDD(object):
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
             for (split, items) in buckets.iteritems():
                 yield pack_long(split)
-                yield dump_pickle(Batch(items))
+                yield outputSerializer._dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
         partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                      id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
-        rdd = RDD(jrdd, self.ctx)
+        rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if
         # partitionFunc is a lambda:
         rdd._partitionFunc = partitionFunc
@@ -789,7 +796,8 @@ class RDD(object):
             numPartitions = self.ctx.defaultParallelism
         def combineLocally(iterator):
             combiners = {}
-            for (k, v) in iterator:
+            for x in iterator:
+                (k, v) = x
                 if k not in combiners:
                     combiners[k] = createCombiner(v)
                 else:
@@ -931,38 +939,38 @@ class PipelinedRDD(RDD):
     20
     """
     def __init__(self, prev, func, preservesPartitioning=False):
-        if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
+        if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
+            # This transformation is the first in its stage:
+            self.func = func
+            self.preservesPartitioning = preservesPartitioning
+            self._prev_jrdd = prev._jrdd
+            self._prev_jrdd_deserializer = prev._jrdd_deserializer
+        else:
             prev_func = prev.func
             def pipeline_func(split, iterator):
                 return func(split, prev_func(split, iterator))
             self.func = pipeline_func
             self.preservesPartitioning = \
                 prev.preservesPartitioning and preservesPartitioning
-            self._prev_jrdd = prev._prev_jrdd
-        else:
-            self.func = func
-            self.preservesPartitioning = preservesPartitioning
-            self._prev_jrdd = prev._jrdd
-        self._stage_input_is_pairs = prev._stage_input_is_pairs
+            self._prev_jrdd = prev._prev_jrdd  # maintain the pipeline
+            self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = prev.ctx
         self.prev = prev
         self._jrdd_val = None
+        self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
 
     @property
     def _jrdd(self):
         if self._jrdd_val:
             return self._jrdd_val
-        func = self.func
-        if not self._bypass_serializer and self.ctx.batchSize != 1:
-            oldfunc = self.func
-            batchSize = self.ctx.batchSize
-            def batched_func(split, iterator):
-                return batched(oldfunc(split, iterator), batchSize)
-            func = batched_func
-        cmds = [func, self._bypass_serializer, self._stage_input_is_pairs]
+        if self._bypass_serializer:
+            serializer = NoOpSerializer()
+        else:
+            serializer = self.ctx.serializer
+        cmds = [self.func, self._prev_jrdd_deserializer, serializer]
         pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index fd02e1e..4fb4444 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -15,8 +15,58 @@
 # limitations under the License.
 #
 
-import struct
+"""
+PySpark supports custom serializers for transferring data; this can improve
+performance.
+
+By default, PySpark uses L{PickleSerializer} to serialize objects using Python's
+C{cPickle} serializer, which can serialize nearly any Python object.
+Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be
+faster.
+
+The serializer is chosen when creating L{SparkContext}:
+
+>>> from pyspark.context import SparkContext
+>>> from pyspark.serializers import MarshalSerializer
+>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
+>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+>>> sc.stop()
+
+By default, PySpark serializes objects in batches; the batch size can be
+controlled through SparkContext's C{batchSize} parameter
+(the default size is 1024 objects):
+
+>>> sc = SparkContext('local', 'test', batchSize=2)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+
+Behind the scenes, this creates a JavaRDD with four partitions, each of
+which contains two batches of two objects:
+
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+8L
+>>> sc.stop()
+
+A batch size of -1 uses an unlimited batch size, and a size of 1 disables
+batching:
+
+>>> sc = SparkContext('local', 'test', batchSize=1)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+16L
+"""
+
 import cPickle
+from itertools import chain, izip, product
+import marshal
+import struct
+
+
+__all__ = ["PickleSerializer", "MarshalSerializer"]
 
 
 class SpecialLengths(object):
@@ -25,41 +75,206 @@ class SpecialLengths(object):
     TIMING_DATA = -3
 
 
-class Batch(object):
+class Serializer(object):
+
+    def dump_stream(self, iterator, stream):
+        """
+        Serialize an iterator of objects to the output stream.
+        """
+        raise NotImplementedError
+
+    def load_stream(self, stream):
+        """
+        Return an iterator of deserialized objects from the input stream.
+        """
+        raise NotImplementedError
+
+
+    def _load_stream_without_unbatching(self, stream):
+        return self.load_stream(stream)
+
+    # Note: our notion of "equality" is that output generated by
+    # equal serializers can be deserialized using the same serializer.
+
+    # This default implementation handles the simple cases;
+    # subclasses should override __eq__ as appropriate.
+
+    def __eq__(self, other):
+        return isinstance(other, self.__class__)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class FramedSerializer(Serializer):
+    """
+    Serializer that writes objects as a stream of (length, data) pairs,
+    where C{length} is a 32-bit integer and data is C{length} bytes.
+    """
+
+    def dump_stream(self, iterator, stream):
+        for obj in iterator:
+            self._write_with_length(obj, stream)
+
+    def load_stream(self, stream):
+        while True:
+            try:
+                yield self._read_with_length(stream)
+            except EOFError:
+                return
+
+    def _write_with_length(self, obj, stream):
+        serialized = self._dumps(obj)
+        write_int(len(serialized), stream)
+        stream.write(serialized)
+
+    def _read_with_length(self, stream):
+        length = read_int(stream)
+        obj = stream.read(length)
+        if obj == "":
+            raise EOFError
+        return self._loads(obj)
+
+    def _dumps(self, obj):
+        """
+        Serialize an object into a byte array.
+        When batching is used, this will be called with an array of objects.
+        """
+        raise NotImplementedError
+
+    def _loads(self, obj):
+        """
+        Deserialize an object from a byte array.
+        """
+        raise NotImplementedError
+
+
+class BatchedSerializer(Serializer):
+    """
+    Serializes a stream of objects in batches by calling its wrapped
+    Serializer with streams of objects.
+    """
+
+    UNLIMITED_BATCH_SIZE = -1
+
+    def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
+        self.serializer = serializer
+        self.batchSize = batchSize
+
+    def _batched(self, iterator):
+        if self.batchSize == self.UNLIMITED_BATCH_SIZE:
+            yield list(iterator)
+        else:
+            items = []
+            count = 0
+            for item in iterator:
+                items.append(item)
+                count += 1
+                if count == self.batchSize:
+                    yield items
+                    items = []
+                    count = 0
+            if items:
+                yield items
+
+    def dump_stream(self, iterator, stream):
+        if isinstance(iterator, basestring):
+            iterator = [iterator]
+        self.serializer.dump_stream(self._batched(iterator), stream)
+
+    def load_stream(self, stream):
+        return chain.from_iterable(self._load_stream_without_unbatching(stream))
+
+    def _load_stream_without_unbatching(self, stream):
+            return self.serializer.load_stream(stream)
+
+    def __eq__(self, other):
+        return isinstance(other, BatchedSerializer) and \
+               other.serializer == self.serializer
+
+    def __str__(self):
+        return "BatchedSerializer<%s>" % str(self.serializer)
+
+
+class CartesianDeserializer(FramedSerializer):
     """
-    Used to store multiple RDD entries as a single Java object.
+    Deserializes the JavaRDD cartesian() of two PythonRDDs.
+    """
+
+    def __init__(self, key_ser, val_ser):
+        self.key_ser = key_ser
+        self.val_ser = val_ser
+
+    def load_stream(self, stream):
+        key_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_stream = self.val_ser._load_stream_without_unbatching(stream)
+        key_is_batched = isinstance(self.key_ser, BatchedSerializer)
+        val_is_batched = isinstance(self.val_ser, BatchedSerializer)
+        for (keys, vals) in izip(key_stream, val_stream):
+            keys = keys if key_is_batched else [keys]
+            vals = vals if val_is_batched else [vals]
+            for pair in product(keys, vals):
+                yield pair
+
+    def __eq__(self, other):
+        return isinstance(other, CartesianDeserializer) and \
+               self.key_ser == other.key_ser and self.val_ser == other.val_ser
+
+    def __str__(self):
+        return "CartesianDeserializer<%s, %s>" % \
+               (str(self.key_ser), str(self.val_ser))
+
+
+class NoOpSerializer(FramedSerializer):
+
+    def _loads(self, obj): return obj
+    def _dumps(self, obj): return obj
+
+
+class PickleSerializer(FramedSerializer):
+    """
+    Serializes objects using Python's cPickle serializer:
+
+        http://docs.python.org/2/library/pickle.html
+
+    This serializer supports nearly any Python object, but may
+    not be as fast as more specialized serializers.
+    """
+
+    def _dumps(self, obj): return cPickle.dumps(obj, 2)
+    _loads = cPickle.loads
+
 
-    This relieves us from having to explicitly track whether an RDD
-    is stored as batches of objects and avoids problems when processing
-    the union() of batched and unbatched RDDs (e.g. the union() of textFile()
-    with another RDD).
+class MarshalSerializer(FramedSerializer):
     """
-    def __init__(self, items):
-        self.items = items
+    Serializes objects using Python's Marshal serializer:
 
+        http://docs.python.org/2/library/marshal.html
 
-def batched(iterator, batchSize):
-    if batchSize == -1: # unlimited batch size
-        yield Batch(list(iterator))
-    else:
-        items = []
-        count = 0
-        for item in iterator:
-            items.append(item)
-            count += 1
-            if count == batchSize:
-                yield Batch(items)
-                items = []
-                count = 0
-        if items:
-            yield Batch(items)
+    This serializer is faster than PickleSerializer but supports fewer datatypes.
+    """
+
+    _dumps = marshal.dumps
+    _loads = marshal.loads
 
 
-def dump_pickle(obj):
-    return cPickle.dumps(obj, 2)
+class MUTF8Deserializer(Serializer):
+    """
+    Deserializes streams written by Java's DataOutputStream.writeUTF().
+    """
 
+    def _loads(self, stream):
+        length = struct.unpack('>H', stream.read(2))[0]
+        return stream.read(length).decode('utf8')
 
-load_pickle = cPickle.loads
+    def load_stream(self, stream):
+        while True:
+            try:
+                yield self._loads(stream)
+            except struct.error:
+                return
+            except EOFError:
+                return
 
 
 def read_long(stream):
@@ -90,43 +305,4 @@ def write_int(value, stream):
 
 def write_with_length(obj, stream):
     write_int(len(obj), stream)
-    stream.write(obj)
-
-
-def read_mutf8(stream):
-    """
-    Read a string written with Java's DataOutputStream.writeUTF() method.
-    """
-    length = struct.unpack('>H', stream.read(2))[0]
-    return stream.read(length).decode('utf8')
-
-
-def read_with_length(stream):
-    length = read_int(stream)
-    obj = stream.read(length)
-    if obj == "":
-        raise EOFError
-    return obj
-
-
-def read_from_pickle_file(stream):
-    try:
-        while True:
-            obj = load_pickle(read_with_length(stream))
-            if type(obj) == Batch:  # We don't care about inheritance
-                for item in obj.items:
-                    yield item
-            else:
-                yield obj
-    except EOFError:
-        return
-
-
-def read_pairs_from_pickle_file(stream):
-    try:
-        while True:
-            a = load_pickle(read_with_length(stream))
-            b = load_pickle(read_with_length(stream))
-            yield (a, b)
-    except EOFError:
-        return
\ No newline at end of file
+    stream.write(obj)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 29d6a12..621e1cb 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -86,7 +86,8 @@ class TestCheckpoint(PySparkTestCase):
         time.sleep(1)  # 1 second
 
         self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
-        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
+                                            flatMappedRDD._jrdd_deserializer)
         self.assertEquals([1, 2, 3, 4], recovered.collect())
 
 

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 4e64557..5b16d5d 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,13 +30,17 @@ from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
-from pyspark.serializers import write_with_length, read_with_length, write_int, \
-    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \
-    SpecialLengths, read_mutf8, read_pairs_from_pickle_file
+from pyspark.serializers import write_with_length, write_int, read_long, \
+    write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
+
+
+pickleSer = PickleSerializer()
+mutf8_deserializer = MUTF8Deserializer()
 
 
 def load_obj(infile):
-    return load_pickle(standard_b64decode(infile.readline().strip()))
+    decoded = standard_b64decode(infile.readline().strip())
+    return pickleSer._loads(decoded)
 
 
 def report_times(outfile, boot, init, finish):
@@ -53,7 +57,7 @@ def main(infile, outfile):
         return
 
     # fetch name of workdir
-    spark_files_dir = read_mutf8(infile)
+    spark_files_dir = mutf8_deserializer._loads(infile)
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
 
@@ -61,31 +65,24 @@ def main(infile, outfile):
     num_broadcast_variables = read_int(infile)
     for _ in range(num_broadcast_variables):
         bid = read_long(infile)
-        value = read_with_length(infile)
-        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+        value = pickleSer._read_with_length(infile)
+        _broadcastRegistry[bid] = Broadcast(bid, value)
 
     # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
     sys.path.append(spark_files_dir) # *.py files that were added will be copied here
     num_python_includes =  read_int(infile)
     for _ in range(num_python_includes):
-        sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile)))
+        filename = mutf8_deserializer._loads(infile)
+        sys.path.append(os.path.join(spark_files_dir, filename))
 
-    # now load function
+    # Load this stage's function and serializer:
     func = load_obj(infile)
-    bypassSerializer = load_obj(infile)
-    stageInputIsPairs = load_obj(infile)
-    if bypassSerializer:
-        dumps = lambda x: x
-    else:
-        dumps = dump_pickle
+    deserializer = load_obj(infile)
+    serializer = load_obj(infile)
     init_time = time.time()
-    if stageInputIsPairs:
-        iterator = read_pairs_from_pickle_file(infile)
-    else:
-        iterator = read_from_pickle_file(infile)
     try:
-        for obj in func(split_index, iterator):
-            write_with_length(dumps(obj), outfile)
+        iterator = deserializer.load_stream(infile)
+        serializer.dump_stream(func(split_index, iterator), outfile)
     except Exception as e:
         write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
         write_with_length(traceback.format_exc(), outfile)
@@ -96,7 +93,7 @@ def main(infile, outfile):
     write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
     write_int(len(_accumulatorRegistry), outfile)
     for (aid, accum) in _accumulatorRegistry.items():
-        write_with_length(dump_pickle((aid, accum._value)), outfile)
+        pickleSer._write_with_length((aid, accum._value), outfile)
 
 
 if __name__ == '__main__':

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/run-tests
----------------------------------------------------------------------
diff --git a/python/run-tests b/python/run-tests
index cbc554e..d4dad67 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -37,6 +37,7 @@ run_test "pyspark/rdd.py"
 run_test "pyspark/context.py"
 run_test "-m doctest pyspark/broadcast.py"
 run_test "-m doctest pyspark/accumulators.py"
+run_test "-m doctest pyspark/serializers.py"
 run_test "pyspark/tests.py"
 
 if [[ $FAILED != 0 ]]; then


[7/7] git commit: Merge pull request #146 from JoshRosen/pyspark-custom-serializers

Posted by ma...@apache.org.
Merge pull request #146 from JoshRosen/pyspark-custom-serializers

Custom Serializers for PySpark

This pull request adds support for custom serializers to PySpark.  For now, all Python-transformed (or parallelize()d) RDDs are serialized with the same serializer that's specified when creating SparkContext.

For now, PySpark includes `PickleSerDe` and `MarshalSerDe` classes for using Python's `pickle` and `marshal` serializers.  It's pretty easy to add support for other serializers, although I still need to add instructions on this.

A few notable changes:

- The Scala `PythonRDD` class no longer manipulates Pickled objects; data from `textFile` is written to Python as MUTF-8 strings.  The Python code performs the appropriate bookkeeping to track which deserializer should be used when reading an underlying JavaRDD.  This mechanism could also be used to support other data exchange formats, such as MsgPack.
- Several magic numbers were refactored into constants.
- Batching is implemented by wrapping / decorating an unbatched SerDe.


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/fb6875dd
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/fb6875dd
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/fb6875dd

Branch: refs/heads/master
Commit: fb6875dd5c9334802580155464cef9ac4d4cc1f0
Parents: 330ada1 1b74a27
Author: Matei Zaharia <ma...@eecs.berkeley.edu>
Authored: Tue Nov 26 20:55:40 2013 -0800
Committer: Matei Zaharia <ma...@eecs.berkeley.edu>
Committed: Tue Nov 26 20:55:40 2013 -0800

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 149 +++------
 python/epydoc.conf                              |   2 +-
 python/pyspark/accumulators.py                  |   6 +-
 python/pyspark/context.py                       |  71 +++--
 python/pyspark/rdd.py                           |  97 +++---
 python/pyspark/serializers.py                   | 301 ++++++++++++++++---
 python/pyspark/tests.py                         |   3 +-
 python/pyspark/worker.py                        |  44 ++-
 python/run-tests                                |   1 +
 9 files changed, 428 insertions(+), 246 deletions(-)
----------------------------------------------------------------------



[6/7] git commit: Removed unused basestring case from dump_stream.

Posted by ma...@apache.org.
Removed unused basestring case from dump_stream.


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/1b74a27d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/1b74a27d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/1b74a27d

Branch: refs/heads/master
Commit: 1b74a27da026aba7dbe2088ee64974d772feb23d
Parents: 13122ce
Author: Josh Rosen <jo...@apache.org>
Authored: Tue Nov 26 14:35:12 2013 -0800
Committer: Josh Rosen <jo...@apache.org>
Committed: Tue Nov 26 14:35:12 2013 -0800

----------------------------------------------------------------------
 python/pyspark/serializers.py | 2 --
 1 file changed, 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/1b74a27d/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 9338df6..811fa6f 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -179,8 +179,6 @@ class BatchedSerializer(Serializer):
                 yield items
 
     def dump_stream(self, iterator, stream):
-        if isinstance(iterator, basestring):
-            iterator = [iterator]
         self.serializer.dump_stream(self._batched(iterator), stream)
 
     def load_stream(self, stream):


[4/7] git commit: Send PySpark commands as bytes instead of strings.

Posted by ma...@apache.org.
Send PySpark commands as bytes instead of strings.


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/ffa5bedf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/ffa5bedf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/ffa5bedf

Branch: refs/heads/master
Commit: ffa5bedf46fbc89ad5c5658f3b423dfff49b70f0
Parents: cbb7f04
Author: Josh Rosen <jo...@apache.org>
Authored: Sun Nov 10 12:58:28 2013 -0800
Committer: Josh Rosen <jo...@apache.org>
Committed: Sun Nov 10 16:46:00 2013 -0800

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 24 ++++----------------
 python/pyspark/rdd.py                           | 12 +++++-----
 python/pyspark/serializers.py                   |  5 ++++
 python/pyspark/worker.py                        | 12 ++--------
 4 files changed, 17 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/ffa5bedf/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ef9bf4d..132e4fb 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -27,13 +27,12 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.PipedRDD
 import org.apache.spark.util.Utils
 
 
 private[spark] class PythonRDD[T: ClassManifest](
     parent: RDD[T],
-    command: Seq[String],
+    command: Array[Byte],
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     preservePartitoning: Boolean,
@@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassManifest](
 
   val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
 
-  // Similar to Runtime.exec(), if we are given a single string, split it into words
-  // using a standard StringTokenizer (i.e. by spaces)
-  def this(parent: RDD[T], command: String, envVars: JMap[String, String],
-      pythonIncludes: JList[String],
-      preservePartitoning: Boolean, pythonExec: String,
-      broadcastVars: JList[Broadcast[Array[Byte]]],
-      accumulator: Accumulator[JList[Array[Byte]]]) =
-    this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec,
-      broadcastVars, accumulator)
-
   override def getPartitions = parent.partitions
 
   override val partitioner = if (preservePartitoning) parent.partitioner else None
 
-
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
     val startTime = System.currentTimeMillis
     val env = SparkEnv.get
@@ -71,7 +59,6 @@ private[spark] class PythonRDD[T: ClassManifest](
           SparkEnv.set(env)
           val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
           val dataOut = new DataOutputStream(stream)
-          val printOut = new PrintWriter(stream)
           // Partition index
           dataOut.writeInt(split.index)
           // sparkFilesDir
@@ -87,17 +74,14 @@ private[spark] class PythonRDD[T: ClassManifest](
           dataOut.writeInt(pythonIncludes.length)
           pythonIncludes.foreach(dataOut.writeUTF)
           dataOut.flush()
-          // Serialized user code
-          for (elem <- command) {
-            printOut.println(elem)
-          }
-          printOut.flush()
+          // Serialized command:
+          dataOut.writeInt(command.length)
+          dataOut.write(command)
           // Data values
           for (elem <- parent.iterator(split, context)) {
             PythonRDD.writeToStream(elem, dataOut)
           }
           dataOut.flush()
-          printOut.flush()
           worker.shutdownOutput()
         } catch {
           case e: IOException =>

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/ffa5bedf/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 6691c30..062f44f 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -27,9 +27,8 @@ from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
 
-from pyspark import cloudpickle
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
-    BatchedSerializer, pack_long
+    BatchedSerializer, CloudPickleSerializer, pack_long
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -970,8 +969,8 @@ class PipelinedRDD(RDD):
             serializer = NoOpSerializer()
         else:
             serializer = self.ctx.serializer
-        cmds = [self.func, self._prev_jrdd_deserializer, serializer]
-        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
+        command = (self.func, self._prev_jrdd_deserializer, serializer)
+        pickled_command = CloudPickleSerializer()._dumps(command)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
@@ -982,8 +981,9 @@ class PipelinedRDD(RDD):
         includes = ListConverter().convert(self.ctx._python_includes,
                                      self.ctx._gateway._gateway_client)
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-            pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
-            broadcast_vars, self.ctx._javaAccumulator, class_manifest)
+            bytearray(pickled_command), env, includes, self.preservesPartitioning,
+            self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator,
+            class_manifest)
         self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
 

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/ffa5bedf/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 4fb4444..b23804b 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -64,6 +64,7 @@ import cPickle
 from itertools import chain, izip, product
 import marshal
 import struct
+from pyspark import cloudpickle
 
 
 __all__ = ["PickleSerializer", "MarshalSerializer"]
@@ -244,6 +245,10 @@ class PickleSerializer(FramedSerializer):
     def _dumps(self, obj): return cPickle.dumps(obj, 2)
     _loads = cPickle.loads
 
+class CloudPickleSerializer(PickleSerializer):
+
+    def _dumps(self, obj): return cloudpickle.dumps(obj, 2)
+
 
 class MarshalSerializer(FramedSerializer):
     """

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/ffa5bedf/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 5b16d5d..2751f12 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,7 +23,6 @@ import sys
 import time
 import socket
 import traceback
-from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
 # copy_reg module.
 from pyspark.accumulators import _accumulatorRegistry
@@ -38,11 +37,6 @@ pickleSer = PickleSerializer()
 mutf8_deserializer = MUTF8Deserializer()
 
 
-def load_obj(infile):
-    decoded = standard_b64decode(infile.readline().strip())
-    return pickleSer._loads(decoded)
-
-
 def report_times(outfile, boot, init, finish):
     write_int(SpecialLengths.TIMING_DATA, outfile)
     write_long(1000 * boot, outfile)
@@ -75,10 +69,8 @@ def main(infile, outfile):
         filename = mutf8_deserializer._loads(infile)
         sys.path.append(os.path.join(spark_files_dir, filename))
 
-    # Load this stage's function and serializer:
-    func = load_obj(infile)
-    deserializer = load_obj(infile)
-    serializer = load_obj(infile)
+    command = pickleSer._read_with_length(infile)
+    (func, deserializer, serializer) = command
     init_time = time.time()
     try:
         iterator = deserializer.load_stream(infile)