You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2015/12/01 19:23:01 UTC
spark git commit: [SPARK-12068][SQL] use a single column in
Dataset.groupBy and count will fail
Repository: spark
Updated Branches:
refs/heads/master 69dbe6b40 -> 8ddc55f1d
[SPARK-12068][SQL] use a single column in Dataset.groupBy and count will fail
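The reason is that when a single-column `RowEncoder` (or a single-field product encoder) is used as the encoder for the grouping key, the grouping attributes must still be combined into a struct, even though there is only one grouping attribute. The fix therefore decides whether to wrap the grouping attributes in a struct based on whether the key encoder is flat, rather than on the number of grouping attributes.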
Author: Wenchen Fan <we...@databricks.com>
Closes #10059 from cloud-fan/bug.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8ddc55f1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8ddc55f1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8ddc55f1
Branch: refs/heads/master
Commit: 8ddc55f1d582cccc3ca135510b2ea776e889e481
Parents: 69dbe6b
Author: Wenchen Fan <we...@databricks.com>
Authored: Tue Dec 1 10:22:55 2015 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Dec 1 10:22:55 2015 -0800
----------------------------------------------------------------------
.../scala/org/apache/spark/sql/Dataset.scala | 2 +-
.../org/apache/spark/sql/GroupedDataset.scala | 7 ++++---
.../org/apache/spark/sql/DatasetSuite.scala | 19 +++++++++++++++++++
.../scala/org/apache/spark/sql/QueryTest.scala | 6 +++---
4 files changed, 27 insertions(+), 7 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/8ddc55f1/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index da46001..c357f88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -70,7 +70,7 @@ class Dataset[T] private[sql](
* implicit so that we can use it when constructing new [[Dataset]] objects that have the same
* object type (that will be possibly resolved to a different schema).
*/
- private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
+ private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
/** The encoder for this [[Dataset]] that has been resolved to its output schema. */
private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
http://git-wip-us.apache.org/repos/asf/spark/blob/8ddc55f1/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index a10a893..4bf0b25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -228,10 +228,11 @@ class GroupedDataset[K, V] private[sql](
val namedColumns =
columns.map(
_.withInputType(resolvedVEncoder, dataAttributes).named)
- val keyColumn = if (groupingAttributes.length > 1) {
- Alias(CreateStruct(groupingAttributes), "key")()
- } else {
+ val keyColumn = if (resolvedKEncoder.flat) {
+ assert(groupingAttributes.length == 1)
groupingAttributes.head
+ } else {
+ Alias(CreateStruct(groupingAttributes), "key")()
}
val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan)
val execution = new QueryExecution(sqlContext, aggregate)
http://git-wip-us.apache.org/repos/asf/spark/blob/8ddc55f1/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 7d53918..a2c8d20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -272,6 +272,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
3 -> "abcxyz", 5 -> "hello")
}
+ test("groupBy single field class, count") {
+ val ds = Seq("abc", "xyz", "hello").toDS()
+ val count = ds.groupBy(s => Tuple1(s.length)).count()
+
+ checkAnswer(
+ count,
+ (Tuple1(3), 2L), (Tuple1(5), 1L)
+ )
+ }
+
test("groupBy columns, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1")
@@ -282,6 +292,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
("a", 30), ("b", 3), ("c", 1))
}
+ test("groupBy columns, count") {
+ val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS()
+ val count = ds.groupBy($"_1").count()
+
+ checkAnswer(
+ count,
+ (Row("a"), 2L), (Row("b"), 1L))
+ }
+
test("groupBy columns asKey, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1").keyAs[String]
http://git-wip-us.apache.org/repos/asf/spark/blob/8ddc55f1/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 6ea1fe4..8f476dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -64,12 +64,12 @@ abstract class QueryTest extends PlanTest {
* for cases where reordering is done on fields. For such tests, user `checkDecoding` instead
* which performs a subset of the checks done by this function.
*/
- protected def checkAnswer[T : Encoder](
- ds: => Dataset[T],
+ protected def checkAnswer[T](
+ ds: Dataset[T],
expectedAnswer: T*): Unit = {
checkAnswer(
ds.toDF(),
- sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq)
+ sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq)
checkDecoding(ds, expectedAnswer: _*)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org