Posted to commits@spark.apache.org by ma...@apache.org on 2014/09/16 20:45:41 UTC

git commit: [SPARK-2314][SQL] Override collect and take in python library, and count in java library, with optimized versions.

Repository: spark
Updated Branches:
  refs/heads/master 30f288ae3 -> 8e7ae477b


[SPARK-2314][SQL] Override collect and take in python library, and count in java library, with optimized versions.

SchemaRDD overrides RDD functions, including collect, count, and take, with optimized versions that make use of the query optimizer.  The Java and Python interface classes wrapping SchemaRDD need to ensure the optimized versions are called as well.  This patch overrides the relevant calls in the Python and Java interfaces with optimized versions.
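
The override itself is mechanical; the important part is that each wrapper delegates to the optimized implementation instead of inheriting the generic RDD behavior. A minimal, self-contained Scala sketch of that forwarding pattern (the class names here are hypothetical stand-ins, not Spark code; the real overrides appear in the diff below):

    // Hypothetical stand-ins: OptimizedTable plays the role of SchemaRDD,
    // LanguageWrapper plays the role of JavaSchemaRDD or the Python SchemaRDD class.
    class OptimizedTable(rows: Seq[String]) {
      // Pretend this implementation goes through the query optimizer.
      def count(): Long = rows.size.toLong
    }

    class LanguageWrapper(val base: OptimizedTable) {
      // Forward to the optimized implementation rather than recomputing generically.
      def count(): Long = base.count()
    }

    object ForwardingSketch extends App {
      val wrapper = new LanguageWrapper(new OptimizedTable(Seq("row1", "row2", "row3")))
      println(wrapper.count())  // 3, computed by the optimized base implementation
    }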

Adds a new Row serialization pathway between Python and Java, based on JList[Array[Byte]] rather than the existing RDD[Array[Byte]]. I wasn't overjoyed about doing this, but some QueryPlans implement optimizations in executeCollect(), which outputs an Array[Row] rather than the typical RDD[Row] that can be shipped to Python with the existing serialization code. It made sense to ship the Array[Row] to Python directly instead of converting it back to an RDD[Row] just so the Rows could be sent through the existing serialization code.
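
Roughly, the new pathway pickles the already-collected rows in batches and returns them as a java.util.List, instead of distributing them as an RDD first. A minimal sketch of that idea, assuming rows modeled as Array[Any] and using Pyrolite's Pickler (the same pickling library the existing RDD[Array[Byte]] pathway uses); the object and method names are illustrative, and the real implementation is collectToPython in SchemaRDD.scala below:

    import java.util.{ArrayList => JArrayList, List => JList}

    import net.razorvine.pickle.Pickler  // Pyrolite, already a Spark dependency

    object CollectToPythonSketch {
      // Pickle already-collected rows (modeled here as Array[Any] per row) into
      // batches of serialized bytes, mirroring the batching done by javaToPython.
      def collectToPython(rows: Array[Array[Any]]): JList[Array[Byte]] = {
        val pickle = new Pickler
        val result = new JArrayList[Array[Byte]]()
        rows.grouped(100).foreach(batch => result.add(pickle.dumps(batch)))
        result
      }

      def main(args: Array[String]): Unit = {
        val rows: Array[Array[Any]] = Array(Array(1, "row1"), Array(2, "row2"))
        // One batch of pickled rows, ready to be handed to the Python side.
        println(collectToPython(rows).size())
      }
    }

Where javaToPython builds an RDD[Array[Byte]] whose batches are pickled on the executors, this list is assembled once on the driver from rows that executeCollect() has already materialized.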

Author: Aaron Staple <aa...@gmail.com>

Closes #1592 from staple/SPARK-2314 and squashes the following commits:

89ff550 [Aaron Staple] Merge with master.
6bb7b6c [Aaron Staple] Fix typo.
b56d0ac [Aaron Staple] [SPARK-2314][SQL] Override count in JavaSchemaRDD, forwarding to SchemaRDD's count.
0fc9d40 [Aaron Staple] Fix comment typos.
f03cdfa [Aaron Staple] [SPARK-2314][SQL] Override collect and take in sql.py, forwarding to SchemaRDD's collect.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8e7ae477
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8e7ae477
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8e7ae477

Branch: refs/heads/master
Commit: 8e7ae477ba40a064d27cf149aa211ff6108fe239
Parents: 30f288a
Author: Aaron Staple <aa...@gmail.com>
Authored: Tue Sep 16 11:45:35 2014 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Sep 16 11:45:35 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala |  2 +-
 python/pyspark/sql.py                           | 47 +++++++++++++++++---
 .../scala/org/apache/spark/sql/SchemaRDD.scala  | 37 ++++++++++-----
 .../spark/sql/api/java/JavaSchemaRDD.scala      |  2 +
 4 files changed, 71 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8e7ae477/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index d5002fa..12b345a 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -776,7 +776,7 @@ private[spark] object PythonRDD extends Logging {
   }
 
   /**
-   * Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by
+   * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
    * PySpark.
    */
   def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/8e7ae477/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index eac55cb..621a556 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -30,6 +30,7 @@ from operator import itemgetter
 from pyspark.rdd import RDD, PipelinedRDD
 from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
 from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
 
 from itertools import chain, ifilter, imap
 
@@ -1550,6 +1551,18 @@ class SchemaRDD(RDD):
             self._id = self._jrdd.id()
         return self._id
 
+    def limit(self, num):
+        """Limit the result count to the number specified.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.limit(2).collect()
+        [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+        >>> srdd.limit(0).collect()
+        []
+        """
+        rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD()
+        return SchemaRDD(rdd, self.sql_ctx)
+
     def saveAsParquetFile(self, path):
         """Save the contents as a Parquet file, preserving the schema.
 
@@ -1626,15 +1639,39 @@ class SchemaRDD(RDD):
         return self._jschema_rdd.count()
 
     def collect(self):
-        """
-        Return a list that contains all of the rows in this RDD.
+        """Return a list that contains all of the rows in this RDD.
 
-        Each object in the list is on Row, the fields can be accessed as
+        Each object in the list is a Row, the fields can be accessed as
         attributes.
+
+        Unlike the base RDD implementation of collect, this implementation
+        leverages the query optimizer to perform a collect on the SchemaRDD,
+        which supports features such as filter pushdown.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.collect()
+        [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
         """
-        rows = RDD.collect(self)
+        with SCCallSiteSync(self.context) as css:
+            bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator()
         cls = _create_cls(self.schema())
-        return map(cls, rows)
+        return map(cls, self._collect_iterator_through_file(bytesInJava))
+
+    def take(self, num):
+        """Take the first num rows of the RDD.
+
+        Each object in the list is a Row, the fields can be accessed as
+        attributes.
+
+        Unlike the base RDD implementation of take, this implementation
+        leverages the query optimizer to perform a collect on a SchemaRDD,
+        which supports features such as filter pushdown.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.take(2)
+        [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+        """
+        return self.limit(num).collect()
 
     # Convert each object in the RDD to a Row with the right class
     # for this SchemaRDD, so that fields can be accessed as attributes.

http://git-wip-us.apache.org/repos/asf/spark/blob/8e7ae477/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index d2ceb4a..3bc5dce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -377,15 +377,15 @@ class SchemaRDD(
   def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan)
 
   /**
-   * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
+   * Helper for converting a Row to a simple Array suitable for pyspark serialization.
    */
-  private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+  private def rowToJArray(row: Row, structType: StructType): Array[Any] = {
     import scala.collection.Map
 
     def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
       case (null, _) => null
 
-      case (obj: Row, struct: StructType) => rowToArray(obj, struct)
+      case (obj: Row, struct: StructType) => rowToJArray(obj, struct)
 
       case (seq: Seq[Any], array: ArrayType) =>
         seq.map(x => toJava(x, array.elementType)).asJava
@@ -402,23 +402,38 @@ class SchemaRDD(
       case (other, _) => other
     }
 
-    def rowToArray(row: Row, structType: StructType): Array[Any] = {
-      val fields = structType.fields.map(field => field.dataType)
-      row.zip(fields).map {
-        case (obj, dataType) => toJava(obj, dataType)
-      }.toArray
-    }
+    val fields = structType.fields.map(field => field.dataType)
+    row.zip(fields).map {
+      case (obj, dataType) => toJava(obj, dataType)
+    }.toArray
+  }
 
+  /**
+   * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
+   */
+  private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
     val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
     this.mapPartitions { iter =>
       val pickle = new Pickler
       iter.map { row =>
-        rowToArray(row, rowSchema)
+        rowToJArray(row, rowSchema)
       }.grouped(100).map(batched => pickle.dumps(batched.toArray))
     }
   }
 
   /**
+   * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same
+   * format as javaToPython. It is used by pyspark.
+   */
+  private[sql] def collectToPython: JList[Array[Byte]] = {
+    val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
+    val pickle = new Pickler
+    new java.util.ArrayList(collect().map { row =>
+      rowToJArray(row, rowSchema)
+    }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
+  }
+
+  /**
    * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
    * of base RDD functions that do not change schema.
    *
@@ -433,7 +448,7 @@ class SchemaRDD(
   }
 
   // =======================================================================
-  // Overriden RDD actions
+  // Overridden RDD actions
   // =======================================================================
 
   override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()

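The unchanged override at the end of this hunk, collect() = queryExecution.executedPlan.executeCollect(), is the hook the commit message refers to: executeCollect() can be overridden per physical operator, so a plan node is free to substitute a cheaper driver-side strategy than a full distributed collect. A purely illustrative, self-contained sketch of that dispatch, using hypothetical plan-node classes that are not Spark code:

    // Hypothetical plan nodes; rows are modeled as Array[Any] for simplicity.
    // The point is the dispatch: collect() on the root calls executeCollect(),
    // so any operator that overrides executeCollect() gets its version used.
    trait PlanNode {
      def execute(): Iterator[Array[Any]]
      // Generic fallback: drain whatever execute() produces.
      def executeCollect(): Array[Array[Any]] = execute().toArray
    }

    class Scan(rows: Seq[Array[Any]]) extends PlanNode {
      def execute(): Iterator[Array[Any]] = rows.iterator
    }

    class Limit(n: Int, child: PlanNode) extends PlanNode {
      def execute(): Iterator[Array[Any]] = child.execute().take(n)
      // An operator-specific collect; in Spark this is where an operator can
      // avoid the cost of a full distributed collect.
      override def executeCollect(): Array[Array[Any]] = child.execute().take(n).toArray
    }

    object ExecuteCollectSketch extends App {
      // Mirrors SchemaRDD.collect(), which simply calls executedPlan.executeCollect().
      def collectRows(plan: PlanNode): Array[Array[Any]] = plan.executeCollect()

      val plan = new Limit(2, new Scan(Seq(Array(1, "a"), Array(2, "b"), Array(3, "c"))))
      println(collectRows(plan).length)  // 2
    }
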
http://git-wip-us.apache.org/repos/asf/spark/blob/8e7ae477/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
index 4d799b4..e7faba0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
@@ -112,6 +112,8 @@ class JavaSchemaRDD(
     new java.util.ArrayList(arr)
   }
 
+  override def count(): Long = baseSchemaRDD.count
+
   override def take(num: Int): JList[Row] = {
     import scala.collection.JavaConversions._
     val arr: java.util.Collection[Row] = baseSchemaRDD.take(num).toSeq.map(new Row(_))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org