You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2016/08/04 11:45:59 UTC

spark git commit: [SPARK-16853][SQL] fixes encoder error in DataSet typed select

Repository: spark
Updated Branches:
  refs/heads/master 43f4fd6f9 -> 9d7a47406


[SPARK-16853][SQL] fixes encoder error in DataSet typed select

## What changes were proposed in this pull request?

For the typed `select` method of `Dataset`:
```
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1]
```
If type `U1` is a case class or a tuple (i.e. not an atomic type), the resulting logical plan's schema does not match the schema of the `Dataset[U1]` encoder, which causes an encoder error and throws an `AnalysisException`.

### Before change:
```
scala> case class A(a: Int, b: Int)
scala> Seq((0, A(1,2))).toDS.select($"_2".as[A])
org.apache.spark.sql.AnalysisException: cannot resolve '`a`' given input columns: [_2];
..
```

### After change:
```
scala> case class A(a: Int, b: Int)
scala> Seq((0, A(1,2))).toDS.select($"_2".as[A]).show
+---+---+
|  a|  b|
+---+---+
|  1|  2|
+---+---+
```

## How was this patch tested?

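Unit test: added "SPARK-16853: select, case class and tuple" to `DatasetSuite`, covering typed select of both a tuple and a case class.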

Author: Sean Zhong <se...@databricks.com>

Closes #14474 from clockfly/SPARK-16853.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9d7a4740
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9d7a4740
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9d7a4740

Branch: refs/heads/master
Commit: 9d7a47406ed538f0005cdc7a62bc6e6f20634815
Parents: 43f4fd6
Author: Sean Zhong <se...@databricks.com>
Authored: Thu Aug 4 19:45:47 2016 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Aug 4 19:45:47 2016 +0800

----------------------------------------------------------------------
 project/MimaExcludes.scala                      |  4 +++-
 .../catalyst/encoders/ExpressionEncoder.scala   |  4 ++++
 .../scala/org/apache/spark/sql/Dataset.scala    | 20 +++++++++++---------
 .../org/apache/spark/sql/DatasetSuite.scala     | 11 +++++++++++
 4 files changed, 29 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9d7a4740/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 5606155..a201d7f 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -38,7 +38,9 @@ object MimaExcludes {
   lazy val v21excludes = v20excludes ++ {
     Seq(
       // [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter
-      ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references")
+      ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references"),
+      // [SPARK-16853][SQL] Fixes encoder error in DataSet typed select
+      ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select")
     )
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9d7a4740/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 1fac26c..b96b744 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -169,6 +169,10 @@ object ExpressionEncoder {
       ClassTag(cls))
   }
 
+  // Tuple1
+  def tuple[T](e: ExpressionEncoder[T]): ExpressionEncoder[Tuple1[T]] =
+    tuple(Seq(e)).asInstanceOf[ExpressionEncoder[Tuple1[T]]]
+
   def tuple[T1, T2](
       e1: ExpressionEncoder[T1],
       e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =

http://git-wip-us.apache.org/repos/asf/spark/blob/9d7a4740/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 8b6443c..306ca77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1061,15 +1061,17 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
-    new Dataset[U1](
-      sparkSession,
-      Project(
-        c1.withInputType(
-          exprEnc.deserializer,
-          logicalPlan.output).named :: Nil,
-        logicalPlan),
-      implicitly[Encoder[U1]])
+  def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
+    implicit val encoder = c1.encoder
+    val project = Project(c1.withInputType(exprEnc.deserializer, logicalPlan.output).named :: Nil,
+      logicalPlan)
+
+    if (encoder.flat) {
+      new Dataset[U1](sparkSession, project, encoder)
+    } else {
+      // Flattens inner fields of U1
+      new Dataset[Tuple1[U1]](sparkSession, project, ExpressionEncoder.tuple(encoder)).map(_._1)
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/9d7a4740/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 7e3b7b6..8a756fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -184,6 +184,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       2, 3, 4)
   }
 
+  test("SPARK-16853: select, case class and tuple") {
+    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
+    checkDataset(
+      ds.select(expr("struct(_2, _2)").as[(Int, Int)]): Dataset[(Int, Int)],
+      (1, 1), (2, 2), (3, 3))
+
+    checkDataset(
+      ds.select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]): Dataset[ClassData],
+      ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))
+  }
+
   test("select 2") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
     checkDataset(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org