You are viewing a plain text version of this content. The canonical version is linked from the original archive page (the hyperlink was lost in this plain-text rendering).
Posted to commits@spark.apache.org by rx...@apache.org on 2015/05/11 20:35:19 UTC

spark git commit: [SPARK-7462] By default retain group by columns in aggregate

Repository: spark
Updated Branches:
  refs/heads/master 1b4655699 -> 0a4844f90


[SPARK-7462] By default retain group by columns in aggregate

Updated Java, Scala, Python, and R.

Author: Reynold Xin <rx...@databricks.com>
Author: Shivaram Venkataraman <sh...@cs.berkeley.edu>

Closes #5996 from rxin/groupby-retain and squashes the following commits:

aac7119 [Reynold Xin] Merge branch 'groupby-retain' of github.com:rxin/spark into groupby-retain
f6858f6 [Reynold Xin] Merge branch 'master' into groupby-retain
5f923c0 [Reynold Xin] Merge pull request #15 from shivaram/sparkr-groupby-retrain
c1de670 [Shivaram Venkataraman] Revert workaround in SparkR to retain grouped cols. Based on reverting code added in commit https://github.com/amplab-extras/spark/commit/9a6be746efc9fafad88122fa2267862ef87aa0e1
b8b87e1 [Reynold Xin] Fixed DataFrameJoinSuite.
d910141 [Reynold Xin] Updated rest of the files
1e6e666 [Reynold Xin] [SPARK-7462] By default retain group by columns in aggregate


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0a4844f9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0a4844f9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0a4844f9

Branch: refs/heads/master
Commit: 0a4844f90a712e796c9404b422cea76d21a5d2e3
Parents: 1b46556
Author: Reynold Xin <rx...@databricks.com>
Authored: Mon May 11 11:35:16 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Mon May 11 11:35:16 2015 -0700

----------------------------------------------------------------------
 R/pkg/R/group.R                                 |   4 +-
 python/pyspark/sql/dataframe.py                 |   2 +-
 .../org/apache/spark/sql/GroupedData.scala      |  15 +-
 .../scala/org/apache/spark/sql/SQLConf.scala    |   6 +
 .../org/apache/spark/sql/api/r/SQLUtils.scala   |  11 --
 .../sql/execution/stat/StatFunctions.scala      |   2 +-
 .../spark/sql/DataFrameAggregateSuite.scala     | 193 +++++++++++++++++++
 .../apache/spark/sql/DataFrameJoinSuite.scala   |   4 +-
 .../org/apache/spark/sql/DataFrameSuite.scala   | 151 +--------------
 .../scala/org/apache/spark/sql/TestData.scala   |   2 -
 10 files changed, 218 insertions(+), 172 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0a4844f9/R/pkg/R/group.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 5a7a8a2..b758481 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -102,9 +102,7 @@ setMethod("agg",
                 }
               }
               jcols <- lapply(cols, function(c) { c@jc })
-              # the GroupedData.agg(col, cols*) API does not contain grouping Column
-              sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "aggWithGrouping",
-                                 x@sgd, listToSeq(jcols))
+              sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1]))
             } else {
               stop("agg can only support Column or character")
             }

http://git-wip-us.apache.org/repos/asf/spark/blob/0a4844f9/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a969799..c2fa6c8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1069,7 +1069,7 @@ class GroupedData(object):
 
         >>> from pyspark.sql import functions as F
         >>> gdf.agg(F.min(df.age)).collect()
-        [Row(MIN(age)=2), Row(MIN(age)=5)]
+        [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):

http://git-wip-us.apache.org/repos/asf/spark/blob/0a4844f9/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 53ad673..003a620 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -135,8 +135,9 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
   }
 
   /**
-   * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this
-   * class, the resulting [[DataFrame]] won't automatically include the grouping columns.
+   * Compute aggregates by specifying a series of aggregate columns. Note that this function by
+   * default retains the grouping columns in its output. To not retain grouping columns, set
+   * `spark.sql.retainGroupColumns` to false.
    *
    * The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
    *
@@ -158,7 +159,15 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
       case expr: NamedExpression => expr
       case expr: Expression => Alias(expr, expr.prettyString)()
     }
-    DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+    if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+      val retainedExprs = groupingExprs.map {
+        case expr: NamedExpression => expr
+        case expr: Expression => Alias(expr, expr.prettyString)()
+      }
+      DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan))
+    } else {
+      DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/0a4844f9/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 98a75bb..dcac97b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -74,6 +74,9 @@ private[spark] object SQLConf {
   // See SPARK-6231.
   val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = "spark.sql.selfJoinAutoResolveAmbiguity"
 
+  // Whether to retain group by columns or not in GroupedData.agg.
+  val DATAFRAME_RETAIN_GROUP_COLUMNS = "spark.sql.retainGroupColumns"
+
   val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2"
 
   val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI"
@@ -242,6 +245,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
 
   private[spark] def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY, "true").toBoolean
+
+  private[spark] def dataFrameRetainGroupColumns: Boolean =
+    getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean
   
   /** ********************** SQLConf functionality methods ************ */
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0a4844f9/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index ae77f72..423ecdf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -72,17 +72,6 @@ private[r] object SQLUtils {
     sqlContext.createDataFrame(rowRDD, schema)
   }
 
-  // A helper to include grouping columns in Agg()
-  def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = {
-    val aggExprs = exprs.map { col =>
-      col.expr match {
-        case expr: NamedExpression => expr
-        case expr: Expression => Alias(expr, expr.simpleString)()
-      }
-    }
-    gd.toDF(aggExprs)
-  }
-
   def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
     df.map(r => rowToRBytes(r))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/0a4844f9/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 71b7f6c..d22f5fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -104,7 +104,7 @@ private[sql] object StatFunctions extends Logging {
   /** Generate a table of frequencies for the elements of two columns. */
   private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
     val tableName = s"${col1}_$col2"
-    val counts = df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).take(1e6.toInt)
+    val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt)
     if (counts.length == 1e6.toInt) {
       logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
         "the pairs. Please try reducing the amount of distinct items in your columns.")

http://git-wip-us.apache.org/repos/asf/spark/blob/0a4844f9/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
new file mode 100644
index 0000000..35a574f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.types.DecimalType
+
+
+class DataFrameAggregateSuite extends QueryTest {
+
+  test("groupBy") {
+    checkAnswer(
+      testData2.groupBy("a").agg(sum($"b")),
+      Seq(Row(1, 3), Row(2, 3), Row(3, 3))
+    )
+    checkAnswer(
+      testData2.groupBy("a").agg(sum($"b").as("totB")).agg(sum('totB)),
+      Row(9)
+    )
+    checkAnswer(
+      testData2.groupBy("a").agg(count("*")),
+      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
+    )
+    checkAnswer(
+      testData2.groupBy("a").agg(Map("*" -> "count")),
+      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
+    )
+    checkAnswer(
+      testData2.groupBy("a").agg(Map("b" -> "sum")),
+      Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
+    )
+
+    val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
+      .toDF("key", "value1", "value2", "rest")
+
+    checkAnswer(
+      df1.groupBy("key").min(),
+      df1.groupBy("key").min("value1", "value2").collect()
+    )
+    checkAnswer(
+      df1.groupBy("key").min("value2"),
+      Seq(Row("a", 0), Row("b", 4))
+    )
+  }
+
+  test("spark.sql.retainGroupColumns config") {
+    checkAnswer(
+      testData2.groupBy("a").agg(sum($"b")),
+      Seq(Row(1, 3), Row(2, 3), Row(3, 3))
+    )
+
+    TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "false")
+    checkAnswer(
+      testData2.groupBy("a").agg(sum($"b")),
+      Seq(Row(3), Row(3), Row(3))
+    )
+    TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "true")
+  }
+
+  test("agg without groups") {
+    checkAnswer(
+      testData2.agg(sum('b)),
+      Row(9)
+    )
+  }
+
+  test("average") {
+    checkAnswer(
+      testData2.agg(avg('a)),
+      Row(2.0))
+
+    // Also check mean
+    checkAnswer(
+      testData2.agg(mean('a)),
+      Row(2.0))
+
+    checkAnswer(
+      testData2.agg(avg('a), sumDistinct('a)), // non-partial
+      Row(2.0, 6.0) :: Nil)
+
+    checkAnswer(
+      decimalData.agg(avg('a)),
+      Row(new java.math.BigDecimal(2.0)))
+    checkAnswer(
+      decimalData.agg(avg('a), sumDistinct('a)), // non-partial
+      Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
+
+    checkAnswer(
+      decimalData.agg(avg('a cast DecimalType(10, 2))),
+      Row(new java.math.BigDecimal(2.0)))
+    // non-partial
+    checkAnswer(
+      decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))),
+      Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
+  }
+
+  test("null average") {
+    checkAnswer(
+      testData3.agg(avg('b)),
+      Row(2.0))
+
+    checkAnswer(
+      testData3.agg(avg('b), countDistinct('b)),
+      Row(2.0, 1))
+
+    checkAnswer(
+      testData3.agg(avg('b), sumDistinct('b)), // non-partial
+      Row(2.0, 2.0))
+  }
+
+  test("zero average") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    checkAnswer(
+      emptyTableData.agg(avg('a)),
+      Row(null))
+
+    checkAnswer(
+      emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial
+      Row(null, null))
+  }
+
+  test("count") {
+    assert(testData2.count() === testData2.map(_ => 1).count())
+
+    checkAnswer(
+      testData2.agg(count('a), sumDistinct('a)), // non-partial
+      Row(6, 6.0))
+  }
+
+  test("null count") {
+    checkAnswer(
+      testData3.groupBy('a).agg(count('b)),
+      Seq(Row(1,0), Row(2, 1))
+    )
+
+    checkAnswer(
+      testData3.groupBy('a).agg(count('a + 'b)),
+      Seq(Row(1,0), Row(2, 1))
+    )
+
+    checkAnswer(
+      testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)),
+      Row(2, 1, 2, 2, 1)
+    )
+
+    checkAnswer(
+      testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial
+      Row(1, 1, 2)
+    )
+  }
+
+  test("zero count") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    assert(emptyTableData.count() === 0)
+
+    checkAnswer(
+      emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
+      Row(0, null))
+  }
+
+  test("zero sum") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    checkAnswer(
+      emptyTableData.agg(sum('a)),
+      Row(null))
+  }
+
+  test("zero sum distinct") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    checkAnswer(
+      emptyTableData.agg(sumDistinct('a)),
+      Row(null))
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0a4844f9/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index f005f55..787f3f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -77,8 +77,8 @@ class DataFrameJoinSuite extends QueryTest {
       df.join(df, df("key") === df("key") && df("value") === 1),
       Row(1, "1", 1, "1") :: Nil)
 
-    val left = df.groupBy("key").agg($"key", count("*"))
-    val right = df.groupBy("key").agg($"key", sum("key"))
+    val left = df.groupBy("key").agg(count("*"))
+    val right = df.groupBy("key").agg(sum("key"))
     checkAnswer(
       left.join(right, left("key") === right("key")),
       Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil)

http://git-wip-us.apache.org/repos/asf/spark/blob/0a4844f9/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index cf590cb..7552c12 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -22,7 +22,6 @@ import scala.language.postfixOps
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext}
-import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
 import org.apache.spark.sql.test.TestSQLContext.implicits._
 
 
@@ -63,7 +62,7 @@ class DataFrameSuite extends QueryTest {
     val df = Seq((1,(1,1))).toDF()
 
     checkAnswer(
-      df.groupBy("_1").agg(col("_1"), sum("_2._1")).toDF("key", "total"),
+      df.groupBy("_1").agg(sum("_2._1")).toDF("key", "total"),
       Row(1, 1) :: Nil)
   }
 
@@ -128,7 +127,7 @@ class DataFrameSuite extends QueryTest {
       df2
         .select('_1 as 'letter, 'number)
         .groupBy('letter)
-        .agg('letter, countDistinct('number)),
+        .agg(countDistinct('number)),
       Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil
     )
   }
@@ -165,48 +164,6 @@ class DataFrameSuite extends QueryTest {
       testData.select('key).collect().toSeq)
   }
 
-  test("groupBy") {
-    checkAnswer(
-      testData2.groupBy("a").agg($"a", sum($"b")),
-      Seq(Row(1, 3), Row(2, 3), Row(3, 3))
-    )
-    checkAnswer(
-      testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)),
-      Row(9)
-    )
-    checkAnswer(
-      testData2.groupBy("a").agg(col("a"), count("*")),
-      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
-    )
-    checkAnswer(
-      testData2.groupBy("a").agg(Map("*" -> "count")),
-      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
-    )
-    checkAnswer(
-      testData2.groupBy("a").agg(Map("b" -> "sum")),
-      Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
-    )
-
-    val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
-      .toDF("key", "value1", "value2", "rest")
-
-    checkAnswer(
-      df1.groupBy("key").min(),
-      df1.groupBy("key").min("value1", "value2").collect()
-    )
-    checkAnswer(
-      df1.groupBy("key").min("value2"),
-      Seq(Row("a", 0), Row("b", 4))
-    )
-  }
-
-  test("agg without groups") {
-    checkAnswer(
-      testData2.agg(sum('b)),
-      Row(9)
-    )
-  }
-
   test("convert $\"attribute name\" into unresolved attribute") {
     checkAnswer(
       testData.where($"key" === lit(1)).select($"value"),
@@ -303,110 +260,6 @@ class DataFrameSuite extends QueryTest {
       mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
   }
 
-  test("average") {
-    checkAnswer(
-      testData2.agg(avg('a)),
-      Row(2.0))
-
-    // Also check mean
-    checkAnswer(
-      testData2.agg(mean('a)),
-      Row(2.0))
-
-    checkAnswer(
-      testData2.agg(avg('a), sumDistinct('a)), // non-partial
-      Row(2.0, 6.0) :: Nil)
-
-    checkAnswer(
-      decimalData.agg(avg('a)),
-      Row(new java.math.BigDecimal(2.0)))
-    checkAnswer(
-      decimalData.agg(avg('a), sumDistinct('a)), // non-partial
-      Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
-
-    checkAnswer(
-      decimalData.agg(avg('a cast DecimalType(10, 2))),
-      Row(new java.math.BigDecimal(2.0)))
-    // non-partial
-    checkAnswer(
-      decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))),
-      Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
-  }
-
-  test("null average") {
-    checkAnswer(
-      testData3.agg(avg('b)),
-      Row(2.0))
-
-    checkAnswer(
-      testData3.agg(avg('b), countDistinct('b)),
-      Row(2.0, 1))
-
-    checkAnswer(
-      testData3.agg(avg('b), sumDistinct('b)), // non-partial
-      Row(2.0, 2.0))
-  }
-
-  test("zero average") {
-    checkAnswer(
-      emptyTableData.agg(avg('a)),
-      Row(null))
-
-    checkAnswer(
-      emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial
-      Row(null, null))
-  }
-
-  test("count") {
-    assert(testData2.count() === testData2.map(_ => 1).count())
-
-    checkAnswer(
-      testData2.agg(count('a), sumDistinct('a)), // non-partial
-      Row(6, 6.0))
-  }
-
-  test("null count") {
-    checkAnswer(
-      testData3.groupBy('a).agg('a, count('b)),
-      Seq(Row(1,0), Row(2, 1))
-    )
-
-    checkAnswer(
-      testData3.groupBy('a).agg('a, count('a + 'b)),
-      Seq(Row(1,0), Row(2, 1))
-    )
-
-    checkAnswer(
-      testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)),
-      Row(2, 1, 2, 2, 1)
-    )
-
-    checkAnswer(
-      testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial
-      Row(1, 1, 2)
-    )
-  }
-
-  test("zero count") {
-    assert(emptyTableData.count() === 0)
-
-    checkAnswer(
-      emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
-      Row(0, null))
-  }
-
-  test("zero sum") {
-    checkAnswer(
-      emptyTableData.agg(sum('a)),
-      Row(null))
-  }
-
-  test("zero sum distinct") {
-    checkAnswer(
-      emptyTableData.agg(sumDistinct('a)),
-      Row(null))
-  }
-
   test("except") {
     checkAnswer(
       lowerCaseData.except(upperCaseData),

http://git-wip-us.apache.org/repos/asf/spark/blob/0a4844f9/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 225b51b..446771a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -86,8 +86,6 @@ object TestData {
       TestData3(2, Some(2)) :: Nil).toDF()
   testData3.registerTempTable("testData3")
 
-  val emptyTableData = logical.LocalRelation($"a".int, $"b".int)
-
   case class UpperCaseData(N: Int, L: String)
   val upperCaseData =
     TestSQLContext.sparkContext.parallelize(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org