You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2016/02/05 23:34:16 UTC

spark git commit: [SPARK-12939][SQL] migrate encoder resolution logic to Analyzer

Repository: spark
Updated Branches:
  refs/heads/master 7b73f1719 -> 1ed354a53


[SPARK-12939][SQL] migrate encoder resolution logic to Analyzer

https://issues.apache.org/jira/browse/SPARK-12939

Now we will catch `ObjectOperator` in `Analyzer` and resolve the `fromRowExpression/deserializer` inside it. Also update the `MapGroups` and `CoGroup` to pass in `dataAttributes`, so that we can correctly resolve the value deserializer (the `child.output` contains both grouping keys and values, which may mess things up if they have same-name attributes). End-to-end tests are added.

follow-ups:

* remove encoders from typed aggregate expression.
* completely remove resolve/bind in `ExpressionEncoder`

Author: Wenchen Fan <we...@databricks.com>

Closes #10852 from cloud-fan/bug.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1ed354a5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1ed354a5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1ed354a5

Branch: refs/heads/master
Commit: 1ed354a5362967d904e9513e5a1618676c9c67a6
Parents: 7b73f17
Author: Wenchen Fan <we...@databricks.com>
Authored: Fri Feb 5 14:34:12 2016 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Fri Feb 5 14:34:12 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 82 ++++++++++++++------
 .../catalyst/encoders/ExpressionEncoder.scala   | 50 ++++++------
 .../sql/catalyst/encoders/RowEncoder.scala      |  2 +-
 .../sql/catalyst/optimizer/Optimizer.scala      |  9 ++-
 .../sql/catalyst/plans/logical/object.scala     | 68 +++++++++++-----
 .../encoders/EncoderResolutionSuite.scala       |  8 +-
 .../encoders/ExpressionEncoderSuite.scala       |  5 +-
 .../scala/org/apache/spark/sql/Dataset.scala    |  3 +-
 .../org/apache/spark/sql/GroupedDataset.scala   |  3 +
 .../spark/sql/execution/SparkStrategies.scala   |  9 ++-
 .../apache/spark/sql/execution/objects.scala    | 31 ++++----
 .../org/apache/spark/sql/DatasetSuite.scala     | 64 +++++++++++++--
 12 files changed, 230 insertions(+), 104 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1ed354a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b59eb12..cb228cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
+import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans._
@@ -457,25 +458,34 @@ class Analyzer(
       // When resolve `SortOrder`s in Sort based on child, don't report errors as
       // we still have chance to resolve it based on its descendants
       case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
-        val newOrdering = resolveSortOrders(ordering, child, throws = false)
+        val newOrdering =
+          ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
         Sort(newOrdering, global, child)
 
       // A special case for Generate, because the output of Generate should not be resolved by
       // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate.
       case g @ Generate(generator, join, outer, qualifier, output, child)
         if child.resolved && !generator.resolved =>
-        val newG = generator transformUp {
-          case u @ UnresolvedAttribute(nameParts) =>
-            withPosition(u) { child.resolve(nameParts, resolver).getOrElse(u) }
-          case UnresolvedExtractValue(child, fieldExpr) =>
-            ExtractValue(child, fieldExpr, resolver)
-        }
+        val newG = resolveExpression(generator, child, throws = true)
         if (newG.fastEquals(generator)) {
           g
         } else {
           Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
         }
 
+      // A special case for ObjectOperator, because the deserializer expressions in ObjectOperator
+      // should be resolved by their corresponding attributes instead of children's output.
+      case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) =>
+        val deserializerToAttributes = o.deserializers.map {
+          case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes
+        }.toMap
+
+        o.transformExpressions {
+          case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes =>
+            resolveDeserializer(expr, attributes)
+          }.getOrElse(expr)
+        }
+
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString}")
         q transformExpressionsUp  {
@@ -490,6 +500,32 @@ class Analyzer(
         }
     }
 
+    private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = {
+      exprs.exists { expr =>
+        !expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined
+      }
+    }
+
+    def resolveDeserializer(
+        deserializer: Expression,
+        attributes: Seq[Attribute]): Expression = {
+      val unbound = deserializer transform {
+        case b: BoundReference => attributes(b.ordinal)
+      }
+
+      resolveExpression(unbound, LocalRelation(attributes), throws = true) transform {
+        case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
+          val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName)
+          if (outer == null) {
+            throw new AnalysisException(
+              s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
+                "access to the scope that this class was defined in.\n" +
+                "Try moving this class out of its parent class.")
+          }
+          n.copy(outerPointer = Some(Literal.fromObject(outer)))
+      }
+    }
+
     def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
       expressions.map {
         case a: Alias => Alias(a.child, a.name)()
@@ -508,23 +544,20 @@ class Analyzer(
       exprs.exists(_.collect { case _: Star => true }.nonEmpty)
   }
 
-  private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
-    ordering.map { order =>
-      // Resolve SortOrder in one round.
-      // If throws == false or the desired attribute doesn't exist
-      // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one.
-      // Else, throw exception.
-      try {
-        val newOrder = order transformUp {
-          case u @ UnresolvedAttribute(nameParts) =>
-            plan.resolve(nameParts, resolver).getOrElse(u)
-          case UnresolvedExtractValue(child, fieldName) if child.resolved =>
-            ExtractValue(child, fieldName, resolver)
-        }
-        newOrder.asInstanceOf[SortOrder]
-      } catch {
-        case a: AnalysisException if !throws => order
+  private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = {
+    // Resolve expression in one round.
+    // If throws == false or the desired attribute doesn't exist
+    // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one.
+    // Else, throw exception.
+    try {
+      expr transformUp {
+        case u @ UnresolvedAttribute(nameParts) =>
+          withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
+        case UnresolvedExtractValue(child, fieldName) if child.resolved =>
+          ExtractValue(child, fieldName, resolver)
       }
+    } catch {
+      case a: AnalysisException if !throws => expr
     }
   }
 
@@ -619,7 +652,8 @@ class Analyzer(
         ordering: Seq[SortOrder],
         plan: LogicalPlan,
         child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
-      val newOrdering = resolveSortOrders(ordering, child, throws = false)
+      val newOrdering =
+        ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
       // Construct a set that contains all of the attributes that we need to evaluate the
       // ordering.
       val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved)

http://git-wip-us.apache.org/repos/asf/spark/blob/1ed354a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 64832dc..58f6d0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -50,7 +50,7 @@ object ExpressionEncoder {
     val cls = mirror.runtimeClass(typeTag[T].tpe)
     val flat = !classOf[Product].isAssignableFrom(cls)
 
-    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
+    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false)
     val toRowExpression = ScalaReflection.extractorsFor[T](inputObject)
     val fromRowExpression = ScalaReflection.constructorFor[T]
 
@@ -257,12 +257,10 @@ case class ExpressionEncoder[T](
   }
 
   /**
-   * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
-   * given schema.
+   * Validates `fromRowExpression` to make sure it can be resolved by given schema, and produce
+   * friendly error messages to explain why it fails to resolve if there is something wrong.
    */
-  def resolve(
-      schema: Seq[Attribute],
-      outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
+  def validate(schema: Seq[Attribute]): Unit = {
     def fail(st: StructType, maxOrdinal: Int): Unit = {
       throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " +
         "but failed as the number of fields does not line up.\n" +
@@ -270,6 +268,8 @@ case class ExpressionEncoder[T](
         " - Target schema: " + this.schema.simpleString)
     }
 
+    // If this is a tuple encoder or tupled encoder, which means its leaf nodes are all
+    // `BoundReference`, make sure their ordinals are all valid.
     var maxOrdinal = -1
     fromRowExpression.foreach {
       case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal
@@ -279,6 +279,10 @@ case class ExpressionEncoder[T](
       fail(StructType.fromAttributes(schema), maxOrdinal)
     }
 
+    // If we have nested tuple, the `fromRowExpression` will contains `GetStructField` instead of
+    // `UnresolvedExtractValue`, so we need to check if their ordinals are all valid.
+    // Note that, `BoundReference` contains the expected type, but here we need the actual type, so
+    // we unbound it by the given `schema` and propagate the actual type to `GetStructField`.
     val unbound = fromRowExpression transform {
       case b: BoundReference => schema(b.ordinal)
     }
@@ -299,28 +303,24 @@ case class ExpressionEncoder[T](
           fail(schema, maxOrdinal)
         }
     }
+  }
 
-    val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
+  /**
+   * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
+   * given schema.
+   */
+  def resolve(
+      schema: Seq[Attribute],
+      outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
+    val deserializer = SimpleAnalyzer.ResolveReferences.resolveDeserializer(
+      fromRowExpression, schema)
+
+    // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check
+    // analysis, go through optimizer, etc.
+    val plan = Project(Alias(deserializer, "")() :: Nil, LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
     SimpleAnalyzer.checkAnalysis(analyzedPlan)
-    val optimizedPlan = SimplifyCasts(analyzedPlan)
-
-    // In order to construct instances of inner classes (for example those declared in a REPL cell),
-    // we need an instance of the outer scope.  This rule substitues those outer objects into
-    // expressions that are missing them by looking up the name in the SQLContexts `outerScopes`
-    // registry.
-    copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform {
-      case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
-        val outer = outerScopes.get(n.cls.getDeclaringClass.getName)
-        if (outer == null) {
-          throw new AnalysisException(
-            s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " +
-            s"to the scope that this class was defined in. " + "" +
-             "Try moving this class out of its parent class.")
-        }
-
-        n.copy(outerPointer = Some(Literal.fromObject(outer)))
-    })
+    copy(fromRowExpression = SimplifyCasts(analyzedPlan).expressions.head.children.head)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/1ed354a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 89d40b3..d8f755a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -154,7 +154,7 @@ object RowEncoder {
       If(
         IsNull(field),
         Literal.create(null, externalDataTypeFor(f.dataType)),
-        constructorFor(BoundReference(i, f.dataType, f.nullable))
+        constructorFor(field)
       )
     }
     CreateExternalRow(fields)

http://git-wip-us.apache.org/repos/asf/spark/blob/1ed354a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a1ac930..902e180 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -119,10 +119,13 @@ object SamplePushDown extends Rule[LogicalPlan] {
  */
 object EliminateSerialization extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case m @ MapPartitions(_, input, _, child: ObjectOperator)
-        if !input.isInstanceOf[Attribute] && m.input.dataType == child.outputObject.dataType =>
+    case m @ MapPartitions(_, deserializer, _, child: ObjectOperator)
+        if !deserializer.isInstanceOf[Attribute] &&
+          deserializer.dataType == child.outputObject.dataType =>
       val childWithoutSerialization = child.withObjectOutput
-      m.copy(input = childWithoutSerialization.output.head, child = childWithoutSerialization)
+      m.copy(
+        deserializer = childWithoutSerialization.output.head,
+        child = childWithoutSerialization)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1ed354a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 7603480..3f97662 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.ObjectType
+import org.apache.spark.sql.types.{ObjectType, StructType}
 
 /**
  * A trait for logical operators that apply user defined functions to domain objects.
@@ -30,6 +30,15 @@ trait ObjectOperator extends LogicalPlan {
   /** The serializer that is used to produce the output of this operator. */
   def serializer: Seq[NamedExpression]
 
+  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  /**
+   * An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects.
+   * It must also provide the attributes that are available during the resolution of each
+   * deserializer.
+   */
+  def deserializers: Seq[(Expression, Seq[Attribute])]
+
   /**
    * The object type that is produced by the user defined function. Note that the return type here
    * is the same whether or not the operator is output serialized data.
@@ -44,13 +53,13 @@ trait ObjectOperator extends LogicalPlan {
   def withObjectOutput: LogicalPlan = if (output.head.dataType.isInstanceOf[ObjectType]) {
     this
   } else {
-    withNewSerializer(outputObject)
+    withNewSerializer(outputObject :: Nil)
   }
 
   /** Returns a copy of this operator with a different serializer. */
-  def withNewSerializer(newSerializer: NamedExpression): LogicalPlan = makeCopy {
+  def withNewSerializer(newSerializer: Seq[NamedExpression]): LogicalPlan = makeCopy {
     productIterator.map {
-      case c if c == serializer => newSerializer :: Nil
+      case c if c == serializer => newSerializer
       case other: AnyRef => other
     }.toArray
   }
@@ -70,15 +79,16 @@ object MapPartitions {
 
 /**
  * A relation produced by applying `func` to each partition of the `child`.
- * @param input used to extract the input to `func` from an input row.
+ *
+ * @param deserializer used to extract the input to `func` from an input row.
  * @param serializer use to serialize the output of `func`.
  */
 case class MapPartitions(
     func: Iterator[Any] => Iterator[Any],
-    input: Expression,
+    deserializer: Expression,
     serializer: Seq[NamedExpression],
     child: LogicalPlan) extends UnaryNode with ObjectOperator {
-  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+  override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
 }
 
 /** Factory for constructing new `AppendColumn` nodes. */
@@ -97,16 +107,21 @@ object AppendColumns {
 /**
  * A relation produced by applying `func` to each partition of the `child`, concatenating the
  * resulting columns at the end of the input row.
- * @param input used to extract the input to `func` from an input row.
+ *
+ * @param deserializer used to extract the input to `func` from an input row.
  * @param serializer use to serialize the output of `func`.
  */
 case class AppendColumns(
     func: Any => Any,
-    input: Expression,
+    deserializer: Expression,
     serializer: Seq[NamedExpression],
     child: LogicalPlan) extends UnaryNode with ObjectOperator {
+
   override def output: Seq[Attribute] = child.output ++ newColumns
+
   def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
 }
 
 /** Factory for constructing new `MapGroups` nodes. */
@@ -114,6 +129,7 @@ object MapGroups {
   def apply[K : Encoder, T : Encoder, U : Encoder](
       func: (K, Iterator[T]) => TraversableOnce[U],
       groupingAttributes: Seq[Attribute],
+      dataAttributes: Seq[Attribute],
       child: LogicalPlan): MapGroups = {
     new MapGroups(
       func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
@@ -121,6 +137,7 @@ object MapGroups {
       encoderFor[T].fromRowExpression,
       encoderFor[U].namedExpressions,
       groupingAttributes,
+      dataAttributes,
       child)
   }
 }
@@ -129,19 +146,22 @@ object MapGroups {
  * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`.
  * Func is invoked with an object representation of the grouping key an iterator containing the
  * object representation of all the rows with that key.
- * @param keyObject used to extract the key object for each group.
- * @param input used to extract the items in the iterator from an input row.
+ *
+ * @param keyDeserializer used to extract the key object for each group.
+ * @param valueDeserializer used to extract the items in the iterator from an input row.
  * @param serializer use to serialize the output of `func`.
  */
 case class MapGroups(
     func: (Any, Iterator[Any]) => TraversableOnce[Any],
-    keyObject: Expression,
-    input: Expression,
+    keyDeserializer: Expression,
+    valueDeserializer: Expression,
     serializer: Seq[NamedExpression],
     groupingAttributes: Seq[Attribute],
+    dataAttributes: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode with ObjectOperator {
 
-  def output: Seq[Attribute] = serializer.map(_.toAttribute)
+  override def deserializers: Seq[(Expression, Seq[Attribute])] =
+    Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes)
 }
 
 /** Factory for constructing new `CoGroup` nodes. */
@@ -150,8 +170,12 @@ object CoGroup {
       func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
       leftGroup: Seq[Attribute],
       rightGroup: Seq[Attribute],
+      leftData: Seq[Attribute],
+      rightData: Seq[Attribute],
       left: LogicalPlan,
       right: LogicalPlan): CoGroup = {
+    require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))
+
     CoGroup(
       func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
       encoderFor[Key].fromRowExpression,
@@ -160,6 +184,8 @@ object CoGroup {
       encoderFor[Result].namedExpressions,
       leftGroup,
       rightGroup,
+      leftData,
+      rightData,
       left,
       right)
   }
@@ -171,15 +197,21 @@ object CoGroup {
  */
 case class CoGroup(
     func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
-    keyObject: Expression,
-    leftObject: Expression,
-    rightObject: Expression,
+    keyDeserializer: Expression,
+    leftDeserializer: Expression,
+    rightDeserializer: Expression,
     serializer: Seq[NamedExpression],
     leftGroup: Seq[Attribute],
     rightGroup: Seq[Attribute],
+    leftAttr: Seq[Attribute],
+    rightAttr: Seq[Attribute],
     left: LogicalPlan,
     right: LogicalPlan) extends BinaryNode with ObjectOperator {
+
   override def producedAttributes: AttributeSet = outputSet
 
-  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+  override def deserializers: Seq[(Expression, Seq[Attribute])] =
+    // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to resolve
+    // the `keyDeserializer` based on either of them, here we pick the left one.
+    Seq(keyDeserializer -> leftGroup, leftDeserializer -> leftAttr, rightDeserializer -> rightAttr)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1ed354a5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index bc36a55..92a68a4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -127,7 +127,7 @@ class EncoderResolutionSuite extends PlanTest {
 
     {
       val attrs = Seq('a.string, 'b.long, 'c.int)
-      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
         "Try to map struct<a:string,b:bigint,c:int> to Tuple2, " +
           "but failed as the number of fields does not line up.\n" +
           " - Input schema: struct<a:string,b:bigint,c:int>\n" +
@@ -136,7 +136,7 @@ class EncoderResolutionSuite extends PlanTest {
 
     {
       val attrs = Seq('a.string)
-      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
         "Try to map struct<a:string> to Tuple2, " +
           "but failed as the number of fields does not line up.\n" +
           " - Input schema: struct<a:string>\n" +
@@ -149,7 +149,7 @@ class EncoderResolutionSuite extends PlanTest {
 
     {
       val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int))
-      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
         "Try to map struct<x:bigint,y:string,z:int> to Tuple2, " +
           "but failed as the number of fields does not line up.\n" +
           " - Input schema: struct<a:string,b:struct<x:bigint,y:string,z:int>>\n" +
@@ -158,7 +158,7 @@ class EncoderResolutionSuite extends PlanTest {
 
     {
       val attrs = Seq('a.string, 'b.struct('x.long))
-      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
         "Try to map struct<x:bigint> to Tuple2, " +
           "but failed as the number of fields does not line up.\n" +
           " - Input schema: struct<a:string,b:struct<x:bigint>>\n" +

http://git-wip-us.apache.org/repos/asf/spark/blob/1ed354a5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 88c558d..e00060f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -19,13 +19,10 @@ package org.apache.spark.sql.catalyst.encoders
 
 import java.sql.{Date, Timestamp}
 import java.util.Arrays
-import java.util.concurrent.ConcurrentMap
 
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.runtime.universe.TypeTag
 
-import com.google.common.collect.MapMaker
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
@@ -78,7 +75,7 @@ class JavaSerializable(val value: Int) extends Serializable {
 }
 
 class ExpressionEncoderSuite extends SparkFunSuite {
-  OuterScopes.outerScopes.put(getClass.getName, this)
+  OuterScopes.addOuterScope(this)
 
   implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder()
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1ed354a5/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f182270..3787632 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -74,6 +74,7 @@ class Dataset[T] private[sql](
    * same object type (that will be possibly resolved to a different schema).
    */
   private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
+  unresolvedTEncoder.validate(logicalPlan.output)
 
   /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
   private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
@@ -85,7 +86,7 @@ class Dataset[T] private[sql](
    */
   private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
 
-  private implicit def classTag = resolvedTEncoder.clsTag
+  private implicit def classTag = unresolvedTEncoder.clsTag
 
   private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
     this(sqlContext, new QueryExecution(sqlContext, plan), encoder)

http://git-wip-us.apache.org/repos/asf/spark/blob/1ed354a5/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index b3f8284..c0e28f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -116,6 +116,7 @@ class GroupedDataset[K, V] private[sql](
       MapGroups(
         f,
         groupingAttributes,
+        dataAttributes,
         logicalPlan))
   }
 
@@ -310,6 +311,8 @@ class GroupedDataset[K, V] private[sql](
         f,
         this.groupingAttributes,
         other.groupingAttributes,
+        this.dataAttributes,
+        other.dataAttributes,
         this.logicalPlan,
         other.logicalPlan))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/1ed354a5/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 9293e55..830bb01 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -306,11 +306,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.MapPartitions(f, in, out, planLater(child)) :: Nil
       case logical.AppendColumns(f, in, out, child) =>
         execution.AppendColumns(f, in, out, planLater(child)) :: Nil
-      case logical.MapGroups(f, key, in, out, grouping, child) =>
-        execution.MapGroups(f, key, in, out, grouping, planLater(child)) :: Nil
-      case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, left, right) =>
+      case logical.MapGroups(f, key, in, out, grouping, data, child) =>
+        execution.MapGroups(f, key, in, out, grouping, data, planLater(child)) :: Nil
+      case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, left, right) =>
         execution.CoGroup(
-          f, keyObj, lObj, rObj, out, lGroup, rGroup, planLater(left), planLater(right)) :: Nil
+          f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr,
+          planLater(left), planLater(right)) :: Nil
 
       case logical.Repartition(numPartitions, shuffle, child) =>
         if (shuffle) {

http://git-wip-us.apache.org/repos/asf/spark/blob/1ed354a5/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 2acca17..582dda8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -53,14 +53,14 @@ trait ObjectOperator extends SparkPlan {
  */
 case class MapPartitions(
     func: Iterator[Any] => Iterator[Any],
-    input: Expression,
+    deserializer: Expression,
     serializer: Seq[NamedExpression],
     child: SparkPlan) extends UnaryNode with ObjectOperator {
   override def output: Seq[Attribute] = serializer.map(_.toAttribute)
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getObject = generateToObject(input, child.output)
+      val getObject = generateToObject(deserializer, child.output)
       val outputObject = generateToRow(serializer)
       func(iter.map(getObject)).map(outputObject)
     }
@@ -72,7 +72,7 @@ case class MapPartitions(
  */
 case class AppendColumns(
     func: Any => Any,
-    input: Expression,
+    deserializer: Expression,
     serializer: Seq[NamedExpression],
     child: SparkPlan) extends UnaryNode with ObjectOperator {
 
@@ -82,7 +82,7 @@ case class AppendColumns(
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getObject = generateToObject(input, child.output)
+      val getObject = generateToObject(deserializer, child.output)
       val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema)
       val outputObject = generateToRow(serializer)
 
@@ -103,10 +103,11 @@ case class AppendColumns(
  */
 case class MapGroups(
     func: (Any, Iterator[Any]) => TraversableOnce[Any],
-    keyObject: Expression,
-    input: Expression,
+    keyDeserializer: Expression,
+    valueDeserializer: Expression,
     serializer: Seq[NamedExpression],
     groupingAttributes: Seq[Attribute],
+    dataAttributes: Seq[Attribute],
     child: SparkPlan) extends UnaryNode with ObjectOperator {
 
   override def output: Seq[Attribute] = serializer.map(_.toAttribute)
@@ -121,8 +122,8 @@ case class MapGroups(
     child.execute().mapPartitionsInternal { iter =>
       val grouped = GroupedIterator(iter, groupingAttributes, child.output)
 
-      val getKey = generateToObject(keyObject, groupingAttributes)
-      val getValue = generateToObject(input, child.output)
+      val getKey = generateToObject(keyDeserializer, groupingAttributes)
+      val getValue = generateToObject(valueDeserializer, dataAttributes)
       val outputObject = generateToRow(serializer)
 
       grouped.flatMap { case (key, rowIter) =>
@@ -142,12 +143,14 @@ case class MapGroups(
  */
 case class CoGroup(
     func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
-    keyObject: Expression,
-    leftObject: Expression,
-    rightObject: Expression,
+    keyDeserializer: Expression,
+    leftDeserializer: Expression,
+    rightDeserializer: Expression,
     serializer: Seq[NamedExpression],
     leftGroup: Seq[Attribute],
     rightGroup: Seq[Attribute],
+    leftAttr: Seq[Attribute],
+    rightAttr: Seq[Attribute],
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode with ObjectOperator {
 
@@ -164,9 +167,9 @@ case class CoGroup(
       val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
       val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
 
-      val getKey = generateToObject(keyObject, leftGroup)
-      val getLeft = generateToObject(leftObject, left.output)
-      val getRight = generateToObject(rightObject, right.output)
+      val getKey = generateToObject(keyDeserializer, leftGroup)
+      val getLeft = generateToObject(leftDeserializer, leftAttr)
+      val getRight = generateToObject(rightDeserializer, rightAttr)
       val outputObject = generateToRow(serializer)
 
       new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {

http://git-wip-us.apache.org/repos/asf/spark/blob/1ed354a5/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index b69bb21..374f432 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
 
 import scala.language.postfixOps
 
+import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
@@ -45,13 +46,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       1, 1, 1)
   }
 
-
   test("SPARK-12404: Datatype Helper Serializablity") {
     val ds = sparkContext.parallelize((
-          new Timestamp(0),
-          new Date(0),
-          java.math.BigDecimal.valueOf(1),
-          scala.math.BigDecimal(1)) :: Nil).toDS()
+      new Timestamp(0),
+      new Date(0),
+      java.math.BigDecimal.valueOf(1),
+      scala.math.BigDecimal(1)) :: Nil).toDS()
 
     ds.collect()
   }
@@ -523,7 +523,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("verify mismatching field names fail with a good error") {
     val ds = Seq(ClassData("a", 1)).toDS()
     val e = intercept[AnalysisException] {
-      ds.as[ClassData2].collect()
+      ds.as[ClassData2]
     }
     assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage)
   }
@@ -567,6 +567,58 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkAnswer(ds1, DeepNestedStruct(NestedStruct(null)))
     checkAnswer(ds1.toDF(), Row(Row(null)))
   }
+
+  test("support inner class in Dataset") {
+    val outer = new OuterClass
+    OuterScopes.addOuterScope(outer)
+    val ds = Seq(outer.InnerClass("1"), outer.InnerClass("2")).toDS()
+    checkAnswer(ds.map(_.a), "1", "2")
+  }
+
+  test("grouping key and grouped value has field with same name") {
+    val ds = Seq(ClassData("a", 1), ClassData("a", 2)).toDS()
+    val agged = ds.groupBy(d => ClassNullableData(d.a, null)).mapGroups {
+      case (key, values) => key.a + values.map(_.b).sum
+    }
+
+    checkAnswer(agged, "a3")
+  }
+
+  test("cogroup's left and right side has field with same name") {
+    val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+    val right = Seq(ClassNullableData("a", 3), ClassNullableData("b", 4)).toDS()
+    val cogrouped = left.groupBy(_.a).cogroup(right.groupBy(_.a)) {
+      case (key, lData, rData) => Iterator(key + lData.map(_.b).sum + rData.map(_.b.toInt).sum)
+    }
+
+    checkAnswer(cogrouped, "a13", "b24")
+  }
+
+  test("give nice error message when the real number of fields doesn't match encoder schema") {
+    val ds = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+
+    val message = intercept[AnalysisException] {
+      ds.as[(String, Int, Long)]
+    }.message
+    assert(message ==
+      "Try to map struct<a:string,b:int> to Tuple3, " +
+        "but failed as the number of fields does not line up.\n" +
+        " - Input schema: struct<a:string,b:int>\n" +
+        " - Target schema: struct<_1:string,_2:int,_3:bigint>")
+
+    val message2 = intercept[AnalysisException] {
+      ds.as[Tuple1[String]]
+    }.message
+    assert(message2 ==
+      "Try to map struct<a:string,b:int> to Tuple1, " +
+        "but failed as the number of fields does not line up.\n" +
+        " - Input schema: struct<a:string,b:int>\n" +
+        " - Target schema: struct<_1:string>")
+  }
+}
+
+class OuterClass extends Serializable {
+  case class InnerClass(a: String)
 }
 
 case class ClassData(a: String, b: Int)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org