You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2016/08/04 11:45:59 UTC
spark git commit: [SPARK-16853][SQL] fixes encoder error in DataSet
typed select
Repository: spark
Updated Branches:
refs/heads/master 43f4fd6f9 -> 9d7a47406
[SPARK-16853][SQL] fixes encoder error in DataSet typed select
## What changes were proposed in this pull request?
For the Dataset typed select:
```
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1]
```
If type T is a case class or a tuple class that is not atomic, the resulting logical plan's schema will not match the `Dataset[T]` encoder's schema, which causes an encoder error and throws an AnalysisException.
### Before change:
```
scala> case class A(a: Int, b: Int)
scala> Seq((0, A(1,2))).toDS.select($"_2".as[A])
org.apache.spark.sql.AnalysisException: cannot resolve '`a`' given input columns: [_2];
...
```
### After change:
```
scala> case class A(a: Int, b: Int)
scala> Seq((0, A(1,2))).toDS.select($"_2".as[A]).show
+---+---+
| a| b|
+---+---+
| 1| 2|
+---+---+
```
## How was this patch tested?
Unit test.
Author: Sean Zhong <se...@databricks.com>
Closes #14474 from clockfly/SPARK-16853.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9d7a4740
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9d7a4740
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9d7a4740
Branch: refs/heads/master
Commit: 9d7a47406ed538f0005cdc7a62bc6e6f20634815
Parents: 43f4fd6
Author: Sean Zhong <se...@databricks.com>
Authored: Thu Aug 4 19:45:47 2016 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Aug 4 19:45:47 2016 +0800
----------------------------------------------------------------------
project/MimaExcludes.scala | 4 +++-
.../catalyst/encoders/ExpressionEncoder.scala | 4 ++++
.../scala/org/apache/spark/sql/Dataset.scala | 20 +++++++++++---------
.../org/apache/spark/sql/DatasetSuite.scala | 11 +++++++++++
4 files changed, 29 insertions(+), 10 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/9d7a4740/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 5606155..a201d7f 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -38,7 +38,9 @@ object MimaExcludes {
lazy val v21excludes = v20excludes ++ {
Seq(
// [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references")
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references"),
+ // [SPARK-16853][SQL] Fixes encoder error in DataSet typed select
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select")
)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/9d7a4740/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 1fac26c..b96b744 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -169,6 +169,10 @@ object ExpressionEncoder {
ClassTag(cls))
}
+ // Tuple1
+ def tuple[T](e: ExpressionEncoder[T]): ExpressionEncoder[Tuple1[T]] =
+ tuple(Seq(e)).asInstanceOf[ExpressionEncoder[Tuple1[T]]]
+
def tuple[T1, T2](
e1: ExpressionEncoder[T1],
e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
http://git-wip-us.apache.org/repos/asf/spark/blob/9d7a4740/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 8b6443c..306ca77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1061,15 +1061,17 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
- new Dataset[U1](
- sparkSession,
- Project(
- c1.withInputType(
- exprEnc.deserializer,
- logicalPlan.output).named :: Nil,
- logicalPlan),
- implicitly[Encoder[U1]])
+ def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
+ implicit val encoder = c1.encoder
+ val project = Project(c1.withInputType(exprEnc.deserializer, logicalPlan.output).named :: Nil,
+ logicalPlan)
+
+ if (encoder.flat) {
+ new Dataset[U1](sparkSession, project, encoder)
+ } else {
+ // Flattens inner fields of U1
+ new Dataset[Tuple1[U1]](sparkSession, project, ExpressionEncoder.tuple(encoder)).map(_._1)
+ }
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/9d7a4740/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 7e3b7b6..8a756fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -184,6 +184,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
2, 3, 4)
}
+ test("SPARK-16853: select, case class and tuple") {
+ val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
+ checkDataset(
+ ds.select(expr("struct(_2, _2)").as[(Int, Int)]): Dataset[(Int, Int)],
+ (1, 1), (2, 2), (3, 3))
+
+ checkDataset(
+ ds.select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]): Dataset[ClassData],
+ ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))
+ }
+
test("select 2") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
checkDataset(
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org