Posted to commits@spark.apache.org by ma...@apache.org on 2015/12/01 19:23:01 UTC

spark git commit: [SPARK-12068][SQL] use a single column in Dataset.groupBy and count will fail

Repository: spark
Updated Branches:
  refs/heads/master 69dbe6b40 -> 8ddc55f1d


[SPARK-12068][SQL] use a single column in Dataset.groupBy and count will fail

The reason is that, for a single-column `RowEncoder` (or a single-field product encoder) used as the grouping-key encoder, we should still wrap the grouping attributes in a struct, even though there is only one grouping attribute.
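
For reference, a minimal reproduction mirroring the regression tests added below (a sketch, assuming a 1.6-era SQLContext with `sqlContext.implicits._` in scope):

    import sqlContext.implicits._

    val ds = Seq("abc", "xyz", "hello").toDS()

    // A flat key (a bare Int) already worked: the single grouping
    // attribute is used as the key column directly.
    ds.groupBy(s => s.length).count().collect()
    // Array((3,2), (5,1))

    // A single-field product key failed before this fix: the Tuple1
    // encoder expects a struct, but the plan exposed the lone grouping
    // attribute unwrapped.
    ds.groupBy(s => Tuple1(s.length)).count().collect()
    // After the fix: Array((Tuple1(3),2), (Tuple1(5),1))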

Author: Wenchen Fan <we...@databricks.com>

Closes #10059 from cloud-fan/bug.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8ddc55f1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8ddc55f1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8ddc55f1

Branch: refs/heads/master
Commit: 8ddc55f1d582cccc3ca135510b2ea776e889e481
Parents: 69dbe6b
Author: Wenchen Fan <we...@databricks.com>
Authored: Tue Dec 1 10:22:55 2015 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Dec 1 10:22:55 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Dataset.scala     |  2 +-
 .../org/apache/spark/sql/GroupedDataset.scala    |  7 ++++---
 .../org/apache/spark/sql/DatasetSuite.scala      | 19 +++++++++++++++++++
 .../scala/org/apache/spark/sql/QueryTest.scala   |  6 +++---
 4 files changed, 27 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8ddc55f1/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index da46001..c357f88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -70,7 +70,7 @@ class Dataset[T] private[sql](
    * implicit so that we can use it when constructing new [[Dataset]] objects that have the same
    * object type (that will be possibly resolved to a different schema).
    */
-  private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
+  private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
 
   /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
   private[sql] val resolvedTEncoder: ExpressionEncoder[T] =

http://git-wip-us.apache.org/repos/asf/spark/blob/8ddc55f1/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index a10a893..4bf0b25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -228,10 +228,11 @@ class GroupedDataset[K, V] private[sql](
     val namedColumns =
       columns.map(
         _.withInputType(resolvedVEncoder, dataAttributes).named)
-    val keyColumn = if (groupingAttributes.length > 1) {
-      Alias(CreateStruct(groupingAttributes), "key")()
-    } else {
+    val keyColumn = if (resolvedKEncoder.flat) {
+      assert(groupingAttributes.length == 1)
       groupingAttributes.head
+    } else {
+      Alias(CreateStruct(groupingAttributes), "key")()
     }
     val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan)
     val execution = new QueryExecution(sqlContext, aggregate)

http://git-wip-us.apache.org/repos/asf/spark/blob/8ddc55f1/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 7d53918..a2c8d20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -272,6 +272,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       3 -> "abcxyz", 5 -> "hello")
   }
 
+  test("groupBy single field class, count") {
+    val ds = Seq("abc", "xyz", "hello").toDS()
+    val count = ds.groupBy(s => Tuple1(s.length)).count()
+
+    checkAnswer(
+      count,
+      (Tuple1(3), 2L), (Tuple1(5), 1L)
+    )
+  }
+
   test("groupBy columns, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1")
@@ -282,6 +292,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("a", 30), ("b", 3), ("c", 1))
   }
 
+  test("groupBy columns, count") {
+    val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS()
+    val count = ds.groupBy($"_1").count()
+
+    checkAnswer(
+      count,
+      (Row("a"), 2L), (Row("b"), 1L))
+  }
+
   test("groupBy columns asKey, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1").keyAs[String]

http://git-wip-us.apache.org/repos/asf/spark/blob/8ddc55f1/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 6ea1fe4..8f476dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -64,12 +64,12 @@ abstract class QueryTest extends PlanTest {
    *    for cases where reordering is done on fields.  For such tests, user `checkDecoding` instead
    *    which performs a subset of the checks done by this function.
    */
-  protected def checkAnswer[T : Encoder](
-      ds: => Dataset[T],
+  protected def checkAnswer[T](
+      ds: Dataset[T],
       expectedAnswer: T*): Unit = {
     checkAnswer(
       ds.toDF(),
-      sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq)
+      sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq)
 
     checkDecoding(ds, expectedAnswer: _*)
   }

