You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jo...@apache.org on 2014/12/24 02:46:33 UTC

spark git commit: [SPARK-4860][pyspark][sql] speeding up `sample()` and `takeSample()`

Repository: spark
Updated Branches:
  refs/heads/master 7e2deb71c -> fd41eb957


[SPARK-4860][pyspark][sql] speeding up `sample()` and `takeSample()`

This PR modifies the Python `SchemaRDD` to use `sample()` and `takeSample()` from Scala instead of the slower Python implementations from `rdd.py`. This is worthwhile because the `Row`s are already serialized as Java objects.

In order to use the faster `takeSample()`, a `takeSampleToPython()` method was implemented in `SchemaRDD.scala` following the pattern of `collectToPython()`.

Author: jbencook <jb...@gmail.com>
Author: J. Benjamin Cook <jb...@gmail.com>

Closes #3764 from jbencook/master and squashes the following commits:

6fbc769 [J. Benjamin Cook] [SPARK-4860][pyspark][sql] fixing sloppy indentation for takeSampleToPython() arguments
5170da2 [J. Benjamin Cook] [SPARK-4860][pyspark][sql] fixing typo: from RDD to SchemaRDD
de22f70 [jbencook] [SPARK-4860][pyspark][sql] using sample() method from JavaSchemaRDD
b916442 [jbencook] [SPARK-4860][pyspark][sql] adding sample() to JavaSchemaRDD
020cbdf [jbencook] [SPARK-4860][pyspark][sql] using Scala implementations of `sample()` and `takeSample()`


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fd41eb95
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fd41eb95
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fd41eb95

Branch: refs/heads/master
Commit: fd41eb9574280b5cfee9b94b4f92e4c44363fb14
Parents: 7e2deb7
Author: jbencook <jb...@gmail.com>
Authored: Tue Dec 23 17:46:24 2014 -0800
Committer: Josh Rosen <jo...@databricks.com>
Committed: Tue Dec 23 17:46:24 2014 -0800

----------------------------------------------------------------------
 python/pyspark/sql.py                           | 28 ++++++++++++++++++++
 .../scala/org/apache/spark/sql/SchemaRDD.scala  | 15 +++++++++++
 .../spark/sql/api/java/JavaSchemaRDD.scala      |  6 +++++
 3 files changed, 49 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fd41eb95/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 469f824..9807a84 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -2085,6 +2085,34 @@ class SchemaRDD(RDD):
         else:
             raise ValueError("Can only subtract another SchemaRDD")
 
+    def sample(self, withReplacement, fraction, seed=None):
+        """
+        Return a sampled subset of this SchemaRDD.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.sample(False, 0.5, 97).count()
+        2L
+        """
+        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+        seed = seed if seed is not None else random.randint(0, sys.maxint)
+        rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed))
+        return SchemaRDD(rdd, self.sql_ctx)
+
+    def takeSample(self, withReplacement, num, seed=None):
+        """Return a fixed-size sampled subset of this SchemaRDD.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.takeSample(False, 2, 97)
+        [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+        """
+        seed = seed if seed is not None else random.randint(0, sys.maxint)
+        with SCCallSiteSync(self.context) as css:
+            bytesInJava = self._jschema_rdd.baseSchemaRDD() \
+                .takeSampleToPython(withReplacement, num, long(seed)) \
+                .iterator()
+        cls = _create_cls(self.schema())
+        return map(cls, self._collect_iterator_through_file(bytesInJava))
+
 
 def _test():
     import doctest

http://git-wip-us.apache.org/repos/asf/spark/blob/fd41eb95/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 7baf8ff..856b10f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -438,6 +438,21 @@ class SchemaRDD(
   }
 
   /**
+   * Serializes the Array[Row] returned by SchemaRDD's takeSample(), using the same
+   * format as javaToPython and collectToPython. It is used by pyspark.
+   */
+  private[sql] def takeSampleToPython(
+      withReplacement: Boolean,
+      num: Int,
+      seed: Long): JList[Array[Byte]] = {
+    val fieldTypes = schema.fields.map(_.dataType)
+    val pickle = new Pickler
+    new java.util.ArrayList(this.takeSample(withReplacement, num, seed).map { row =>
+      EvaluatePython.rowToArray(row, fieldTypes)
+    }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
+  }
+
+  /**
    * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
    * of base RDD functions that do not change schema.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/fd41eb95/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
index ac4844f..5b9c612 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
@@ -218,4 +218,10 @@ class JavaSchemaRDD(
    */
   def subtract(other: JavaSchemaRDD, p: Partitioner): JavaSchemaRDD =
     this.baseSchemaRDD.subtract(other.baseSchemaRDD, p).toJavaSchemaRDD
+
+  /**
+   * Return a SchemaRDD with a sampled version of the underlying dataset.
+   */
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaSchemaRDD =
+    this.baseSchemaRDD.sample(withReplacement, fraction, seed).toJavaSchemaRDD
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org