Posted to commits@spark.apache.org by rx...@apache.org on 2014/05/25 09:06:55 UTC

git commit: [SPARK-1822] SchemaRDD.count() should use query optimizer

Repository: spark
Updated Branches:
  refs/heads/master 6e9fb6320 -> 6052db9dc


[SPARK-1822] SchemaRDD.count() should use query optimizer

Author: Kan Zhang <kz...@apache.org>

Closes #841 from kanzhang/SPARK-1822 and squashes the following commits:

2f8072a [Kan Zhang] [SPARK-1822] Minor style update
cf4baa4 [Kan Zhang] [SPARK-1822] Adding Scaladoc
e67c910 [Kan Zhang] [SPARK-1822] SchemaRDD.count() should use optimizer
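
For context, a brief usage sketch (mine, not part of this commit): after this change, SchemaRDD.count() is planned by Catalyst as a global COUNT(1) aggregate instead of falling back to the generic RDD.count() scan, so call sites stay the same while the execution path changes. The sketch assumes the 1.0-era SQLContext.createSchemaRDD API available on this branch.

    // Hedged usage sketch, not part of the diff below.
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    case class Person(name: String, age: Int)

    object SchemaRddCountSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("count-sketch"))
        val sqlContext = new SQLContext(sc)

        val people = sc.parallelize(Seq(Person("a", 1), Person("b", 2), Person("c", 3)))
        val schemaRDD = sqlContext.createSchemaRDD(people)

        // Same call as before, but now compiled by the query optimizer into an
        // Aggregate(COUNT(1)) plan rather than a row-by-row RDD scan.
        println(schemaRDD.count()) // 3

        sc.stop()
      }
    }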


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6052db9d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6052db9d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6052db9d

Branch: refs/heads/master
Commit: 6052db9dc10c996215658485e805200e4f0cf549
Parents: 6e9fb63
Author: Kan Zhang <kz...@apache.org>
Authored: Sun May 25 00:06:42 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Sun May 25 00:06:42 2014 -0700

----------------------------------------------------------------------
 python/pyspark/sql.py                                 | 14 +++++++++++++-
 .../spark/sql/catalyst/expressions/aggregates.scala   |  6 +++---
 .../main/scala/org/apache/spark/sql/SchemaRDD.scala   |  9 +++++++++
 .../scala/org/apache/spark/sql/DslQuerySuite.scala    |  9 +++++----
 .../test/scala/org/apache/spark/sql/TestData.scala    |  2 ++
 5 files changed, 32 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6052db9d/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index bbe69e7..f2001af 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -268,7 +268,7 @@ class SchemaRDD(RDD):
     def _jrdd(self):
         """
         Lazy evaluation of PythonRDD object. Only done when a user calls methods defined by the
-        L{pyspark.rdd.RDD} super class (map, count, etc.).
+        L{pyspark.rdd.RDD} super class (map, filter, etc.).
         """
         if not hasattr(self, '_lazy_jrdd'):
             self._lazy_jrdd = self._toPython()._jrdd
@@ -321,6 +321,18 @@ class SchemaRDD(RDD):
         """
         self._jschema_rdd.saveAsTable(tableName)
 
+    def count(self):
+        """
+        Return the number of elements in this RDD.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.count()
+        3L
+        >>> srdd.count() == srdd.map(lambda x: x).count()
+        True
+        """
+        return self._jschema_rdd.count()
+
     def _toPython(self):
         # We have to import the Row class explicitly, so that the reference Pickler has is
         # pyspark.sql.Row instead of __main__.Row

http://git-wip-us.apache.org/repos/asf/spark/blob/6052db9d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 5dbaaa3..1bcd4e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -151,7 +151,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
 case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def references = child.references
   override def nullable = false
-  override def dataType = IntegerType
+  override def dataType = LongType
   override def toString = s"COUNT($child)"
 
   override def asPartial: SplitEvaluation = {
@@ -295,12 +295,12 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
 case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
   def this() = this(null, null) // Required for serialization.
 
-  var count: Int = _
+  var count: Long = _
 
   override def update(input: Row): Unit = {
     val evaluatedExpr = expr.map(_.eval(input))
     if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
-      count += 1
+      count += 1L
     }
   }
 

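A note on the LongType change above, with a tiny illustration (mine, not from the commit): both the declared result type of Count and its running counter are now 64-bit, so a count that exceeds Int.MaxValue no longer wraps around.

    // Illustration only: one more row than a 32-bit counter can represent.
    object CountOverflowSketch extends App {
      val rowCount: Long = Int.MaxValue.toLong + 1L
      println(rowCount.toInt) // -2147483648 -- an Int accumulator overflows here
      println(rowCount)       // 2147483648  -- a Long accumulator keeps the true count
    }
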
http://git-wip-us.apache.org/repos/asf/spark/blob/6052db9d/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 2569815..452da3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -276,6 +276,15 @@ class SchemaRDD(
 
   /**
    * :: Experimental ::
+   * Overriding base RDD implementation to leverage query optimizer
+   */
+  @Experimental
+  override def count(): Long = {
+    groupBy()(Count(Literal(1))).collect().head.getLong(0)
+  }
+
+  /**
+   * :: Experimental ::
    * Applies the given Generator, or table generating function, to this relation.
    *
    * @param generator A table generating function.  The API for such functions is likely to change

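The override above is what the Python count() earlier in this diff delegates to via _jschema_rdd, and it is what the updated "count" test below exercises. As an equivalence sketch (mine, not from the commit; it assumes the schemaRDD value from the earlier example and the catalyst expression classes used in this file):

    // Hedged sketch: the three counts below should agree.
    import org.apache.spark.sql.catalyst.expressions.{Count, Literal}

    val viaOptimizer: Long = schemaRDD.count()           // new Aggregate(COUNT(1)) path
    val viaGroupBy: Long =
      schemaRDD.groupBy()(Count(Literal(1))).collect().head.getLong(0)  // what the override runs
    val viaRdd: Long = schemaRDD.map(_ => 1).count()     // generic RDD scan, as asserted in DslQuerySuite
    assert(viaOptimizer == viaGroupBy && viaGroupBy == viaRdd)
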
http://git-wip-us.apache.org/repos/asf/spark/blob/6052db9d/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index f43e98d..233132a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -108,10 +108,7 @@ class DslQuerySuite extends QueryTest {
   }
 
   test("count") {
-    checkAnswer(
-      testData2.groupBy()(Count(1)),
-      testData2.count()
-    )
+    assert(testData2.count() === testData2.map(_ => 1).count())
   }
 
   test("null count") {
@@ -126,6 +123,10 @@ class DslQuerySuite extends QueryTest {
     )
   }
 
+  test("zero count") {
+    assert(testData4.count() === 0)
+  }
+
   test("inner join where, one match per row") {
     checkAnswer(
       upperCaseData.join(lowerCaseData, Inner).where('n === 'N),

http://git-wip-us.apache.org/repos/asf/spark/blob/6052db9d/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 1aca387..b1eecb4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -47,6 +47,8 @@ object TestData {
       (1, null) ::
       (2, 2) :: Nil)
 
+  val testData4 = logical.LocalRelation('a.int, 'b.int)
+
   case class UpperCaseData(N: Int, L: String)
   val upperCaseData =
     TestSQLContext.sparkContext.parallelize(