Posted to commits@spark.apache.org by yh...@apache.org on 2015/10/27 21:29:02 UTC

[2/2] spark git commit: [SPARK-11347] [SQL] Support for joinWith in Datasets

[SPARK-11347] [SQL] Support for joinWith in Datasets

This PR adds a new operation `joinWith` to a `Dataset`, which returns a `Tuple` for each pair where a given `condition` evaluates to true.

```scala
case class ClassData(a: String, b: Int)

val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
val ds2 = Seq(("a", 1), ("b", 2)).toDS()

> ds1.joinWith(ds2, $"_1" === $"a").collect()
res0: Array((ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2)))
```

This operation is similar to the relational `join` function, with one important difference in the result schema. Since `joinWith` preserves objects present on either side of the join, the result schema is nested into a tuple under the column names `_1` and `_2`.

This type of join can be useful both for preserving type-safety with the original object types and for working with relational data where either side of the join has column names in common.
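
As a quick illustration, here is a minimal sketch (assuming the `ds1` and `ds2` datasets defined above and the standard SQL implicits in scope; the explicit type annotations are for clarity only) of consuming the pair-typed result with ordinary tuple operations:

```scala
import org.apache.spark.sql.Dataset

// joinWith yields a Dataset of pairs, so both sides keep their original types.
val joined: Dataset[(ClassData, (String, Int))] =
  ds1.joinWith(ds2, $"_1" === $"a")

// The nested _1/_2 columns correspond directly to the tuple's elements.
val projected: Dataset[(String, Int)] =
  joined.map { case (left, right) => (left.a, right._2) }
```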

## Required Changes to Encoders
In the process of working on this patch, several deficiencies in the way we were handling encoders were discovered.  Specifically, it turned out to be very difficult to `rebind` the non-expression-based encoders to extract the nested objects from the results of joins (and also typed selects that return tuples).

As a result, the following changes were made:
 - `ClassEncoder` has been renamed to `ExpressionEncoder` and has been improved to also handle primitive types.  Additionally, it is now possible to take arbitrary expression encoders and rewrite them into a single encoder that returns a tuple.
 - All internal operations on `Dataset`s now require an `ExpressionEncoder`.  If a user tries to pass in a non-`ExpressionEncoder`, an error will be thrown.  We can relax this requirement in the future by constructing a wrapper class that uses expressions to project the row to the expected schema, shielding the user's code from the required remapping.  This will give us a nice balance where we don't force user encoders to understand attribute references and binding, but still allow our native encoders to leverage runtime code generation to construct specific encoders for a given schema that avoid an extra remapping step.
 - Additionally, the semantics for different types of objects are now better defined.  As stated in the `ExpressionEncoder` scaladoc:
  - Classes will have their subfields extracted by name using `UnresolvedAttribute` expressions
  and `UnresolvedExtractValue` expressions.
  - Tuples will have their subfields extracted by position using `BoundReference` expressions.
  - Primitives will have their values extracted from the first ordinal with a schema that defaults
  to the name `value`.
 - Finally, the binding lifecycle for `Encoders` has now been unified across the codebase.  Encoders are now `resolved` to the appropriate schema in the constructor of `Dataset`.  This process replaces unresolved expressions with concrete `AttributeReference` expressions.  Binding then happens on demand, when an encoder is going to be used to construct an object.  This closely mirrors the lifecycle for standard expressions when executing normal SQL or `DataFrame` queries (a short sketch of this lifecycle follows this list).
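
To make the resolve/bind lifecycle concrete, here is a minimal sketch that mirrors the round-trip pattern used in `ExpressionEncoderSuite` (in the diff below); the tuple type and values are illustrative:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Derive an encoder for a tuple type via runtime reflection.
val encoder = ExpressionEncoder[(Int, String)]()

// Extraction: object -> InternalRow (tuple fields are extracted by position).
val row = encoder.toRow((1, "a"))

// Resolution replaces unresolved expressions with concrete AttributeReferences...
val attrs = encoder.schema.toAttributes
// ...and binding happens on demand, just before objects are constructed.
val bound = encoder.resolve(attrs).bind(attrs)

// Construction: InternalRow -> object.
val back: (Int, String) = bound.fromRow(row)  // (1, "a")
```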

Author: Michael Armbrust <mi...@databricks.com>

Closes #9300 from marmbrus/datasets-tuples.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5a5f6590
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5a5f6590
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5a5f6590

Branch: refs/heads/master
Commit: 5a5f65905a202e59bc85170b01c57a883718ddf6
Parents: 3bdbbc6
Author: Michael Armbrust <mi...@databricks.com>
Authored: Tue Oct 27 13:28:52 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Tue Oct 27 13:28:52 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    |  43 ++-
 .../sql/catalyst/encoders/ClassEncoder.scala    | 101 -------
 .../spark/sql/catalyst/encoders/Encoder.scala   |  38 +--
 .../catalyst/encoders/ExpressionEncoder.scala   | 217 ++++++++++++++
 .../sql/catalyst/encoders/ProductEncoder.scala  |  47 ---
 .../sql/catalyst/encoders/RowEncoder.scala      |   5 +-
 .../spark/sql/catalyst/encoders/package.scala   |  26 ++
 .../sql/catalyst/encoders/primitiveTypes.scala  | 100 -------
 .../spark/sql/catalyst/encoders/tuples.scala    | 173 -----------
 .../catalyst/plans/logical/basicOperators.scala |  28 +-
 .../encoders/ExpressionEncoderSuite.scala       | 291 +++++++++++++++++++
 .../encoders/PrimitiveEncoderSuite.scala        |  43 ---
 .../catalyst/encoders/ProductEncoderSuite.scala | 282 ------------------
 .../scala/org/apache/spark/sql/DataFrame.scala  |   2 +-
 .../scala/org/apache/spark/sql/Dataset.scala    | 190 +++++++-----
 .../scala/org/apache/spark/sql/SQLContext.scala |   4 +-
 .../org/apache/spark/sql/SQLImplicits.scala     |  13 +-
 .../spark/sql/execution/basicOperators.scala    |  16 +-
 .../org/apache/spark/sql/DatasetSuite.scala     |  89 +++++-
 .../scala/org/apache/spark/sql/QueryTest.scala  |  44 ++-
 20 files changed, 850 insertions(+), 902 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c25161e..9cbb7c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -146,6 +146,10 @@ trait ScalaReflection {
    * row with a compatible schema.  Fields of the row will be extracted using UnresolvedAttributes
    * of the same name as the constructor arguments.  Nested classes will have their fields accessed
    * using UnresolvedExtractValue.
+   *
+   * When used on a primitive type, the constructor will instead default to extracting the value
+   * from ordinal 0 (since there are no names to map to).  The actual location can be moved by
+   * calling unbind/bind with a new schema.
    */
   def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None)
 
@@ -159,8 +163,14 @@ trait ScalaReflection {
         .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
         .getOrElse(UnresolvedAttribute(part))
 
+    /** Returns the current path with a field at ordinal extracted. */
+    def addToPathOrdinal(ordinal: Int, dataType: DataType) =
+      path
+        .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal))
+        .getOrElse(BoundReference(ordinal, dataType, false))
+
     /** Returns the current path or throws an error. */
-    def getPath = path.getOrElse(sys.error("Constructors must start at a class type"))
+    def getPath = path.getOrElse(BoundReference(0, dataTypeFor(tpe), true))
 
     tpe match {
       case t if !dataTypeFor(t).isInstanceOf[ObjectType] =>
@@ -387,12 +397,17 @@ trait ScalaReflection {
         val className: String = t.erasure.typeSymbol.asClass.fullName
         val cls = Utils.classForName(className)
 
-        val arguments = params.head.map { p =>
+        val arguments = params.head.zipWithIndex.map { case (p, i) =>
           val fieldName = p.name.toString
           val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
-          val dataType = dataTypeFor(fieldType)
+          val dataType = schemaFor(fieldType).dataType
 
-          constructorFor(fieldType, Some(addToPath(fieldName)))
+          // For tuples, we grab the inner fields by ordinal instead of name.
+          if (className startsWith "scala.Tuple") {
+            constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
+          } else {
+            constructorFor(fieldType, Some(addToPath(fieldName)))
+          }
         }
 
         val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
@@ -413,7 +428,10 @@ trait ScalaReflection {
   /** Returns expressions for extracting all the fields from the given type. */
   def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
     ScalaReflectionLock.synchronized {
-      extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateNamedStruct]
+      extractorFor(inputObject, typeTag[T].tpe) match {
+        case s: CreateNamedStruct => s
+        case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil)
+      }
     }
   }
 
@@ -602,6 +620,21 @@ trait ScalaReflection {
         case t if t <:< localTypeOf[java.lang.Boolean] =>
           Invoke(inputObject, "booleanValue", BooleanType)
 
+        case t if t <:< definitions.IntTpe =>
+          BoundReference(0, IntegerType, false)
+        case t if t <:< definitions.LongTpe =>
+          BoundReference(0, LongType, false)
+        case t if t <:< definitions.DoubleTpe =>
+          BoundReference(0, DoubleType, false)
+        case t if t <:< definitions.FloatTpe =>
+          BoundReference(0, FloatType, false)
+        case t if t <:< definitions.ShortTpe =>
+          BoundReference(0, ShortType, false)
+        case t if t <:< definitions.ByteTpe =>
+          BoundReference(0, ByteType, false)
+        case t if t <:< definitions.BooleanTpe =>
+          BoundReference(0, BooleanType, false)
+
         case other =>
           throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
deleted file mode 100644
index b484b8f..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, SimpleAnalyzer}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-import org.apache.spark.sql.types.{ObjectType, StructType}
-
-/**
- * A generic encoder for JVM objects.
- *
- * @param schema The schema after converting `T` to a Spark SQL row.
- * @param extractExpressions A set of expressions, one for each top-level field that can be used to
- *                           extract the values from a raw object.
- * @param clsTag A classtag for `T`.
- */
-case class ClassEncoder[T](
-    schema: StructType,
-    extractExpressions: Seq[Expression],
-    constructExpression: Expression,
-    clsTag: ClassTag[T])
-  extends Encoder[T] {
-
-  @transient
-  private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
-  private val inputRow = new GenericMutableRow(1)
-
-  @transient
-  private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
-  private val dataType = ObjectType(clsTag.runtimeClass)
-
-  override def toRow(t: T): InternalRow = {
-    inputRow(0) = t
-    extractProjection(inputRow)
-  }
-
-  override def fromRow(row: InternalRow): T = {
-    constructProjection(row).get(0, dataType).asInstanceOf[T]
-  }
-
-  override def bind(schema: Seq[Attribute]): ClassEncoder[T] = {
-    val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema))
-    val analyzedPlan = SimpleAnalyzer.execute(plan)
-    val resolvedExpression = analyzedPlan.expressions.head.children.head
-    val boundExpression = BindReferences.bindReference(resolvedExpression, schema)
-
-    copy(constructExpression = boundExpression)
-  }
-
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ClassEncoder[T] = {
-    val positionToAttribute = AttributeMap.toIndex(oldSchema)
-    val attributeToNewPosition = AttributeMap.byIndex(newSchema)
-    copy(constructExpression = constructExpression transform {
-      case r: BoundReference =>
-        r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
-    })
-  }
-
-  override def bindOrdinals(schema: Seq[Attribute]): ClassEncoder[T] = {
-    var remaining = schema
-    copy(constructExpression = constructExpression transform {
-      case u: UnresolvedAttribute =>
-        val pos = remaining.head
-        remaining = remaining.drop(1)
-        pos
-    })
-  }
-
-  protected val attrs = extractExpressions.map(_.collect {
-    case a: Attribute => s"#${a.exprId}"
-    case b: BoundReference => s"[${b.ordinal}]"
-  }.headOption.getOrElse(""))
-
-
-  protected val schemaString =
-    schema
-      .zip(attrs)
-      .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
-
-  override def toString: String = s"class[$schemaString]"
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index efb872d..329a132 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -18,10 +18,9 @@
 package org.apache.spark.sql.catalyst.encoders
 
 
+
 import scala.reflect.ClassTag
 
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -30,44 +29,11 @@ import org.apache.spark.sql.types.StructType
  * Encoders are not intended to be thread-safe and thus they are allowed to avoid internal locking
  * and reuse internal buffers to improve performance.
  */
-trait Encoder[T] {
+trait Encoder[T] extends Serializable {
 
   /** Returns the schema of encoding this type of object as a Row. */
   def schema: StructType
 
   /** A ClassTag that can be used to construct an Array to contain a collection of `T`. */
   def clsTag: ClassTag[T]
-
-  /**
-   * Returns an encoded version of `t` as a Spark SQL row.  Note that multiple calls to
-   * toRow are allowed to return the same actual [[InternalRow]] object.  Thus, the caller should
-   * copy the result before making another call if required.
-   */
-  def toRow(t: T): InternalRow
-
-  /**
-   * Returns an object of type `T`, extracting the required values from the provided row.  Note that
-   * you must `bind` an encoder to a specific schema before you can call this function.
-   */
-  def fromRow(row: InternalRow): T
-
-  /**
-   * Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the
-   * given schema.
-   */
-  def bind(schema: Seq[Attribute]): Encoder[T]
-
-  /**
-   * Binds this encoder to the given schema positionally.  In this binding, the first reference to
-   * any input is mapped to `schema(0)`, and so on for each input that is encountered.
-   */
-  def bindOrdinals(schema: Seq[Attribute]): Encoder[T]
-
-  /**
-   * Given an encoder that has already been bound to a given schema, returns a new encoder that
-   * where the positions are mapped from `oldSchema` to `newSchema`.  This can be used, for example,
-   * when you are trying to use an encoder on grouping keys that were orriginally part of a larger
-   * row, but now you have projected out only the key expressions.
-   */
-  def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[T]
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
new file mode 100644
index 0000000..c287aeb
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.util.Utils
+
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType}
+
+/**
+ * A factory for constructing encoders that convert objects and primitives to and from the
+ * internal row format using catalyst expressions and code generation.  By default, the
+ * expressions used to retrieve values from an input row when producing an object will be created as
+ * follows:
+ *  - Classes will have their sub fields extracted by name using [[UnresolvedAttribute]] expressions
+ *    and [[UnresolvedExtractValue]] expressions.
+ *  - Tuples will have their subfields extracted by position using [[BoundReference]] expressions.
+ *  - Primitives will have their values extracted from the first ordinal with a schema that defaults
+ *    to the name `value`.
+ */
+object ExpressionEncoder {
+  def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = {
+    // We convert the not-serializable TypeTag into StructType and ClassTag.
+    val mirror = typeTag[T].mirror
+    val cls = mirror.runtimeClass(typeTag[T].tpe)
+
+    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+    val extractExpression = ScalaReflection.extractorsFor[T](inputObject)
+    val constructExpression = ScalaReflection.constructorFor[T]
+
+    new ExpressionEncoder[T](
+      extractExpression.dataType,
+      flat,
+      extractExpression.flatten,
+      constructExpression,
+      ClassTag[T](cls))
+  }
+
+  /**
+   * Given a set of N encoders, constructs a new encoder that produces objects as items in an
+   * N-tuple.  Note that these encoders should first be bound correctly to the combined input
+   * schema.
+   */
+  def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+    val schema =
+      StructType(
+        encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)})
+    val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
+    val extractExpressions = encoders.map {
+      case e if e.flat => e.extractExpressions.head
+      case other => CreateStruct(other.extractExpressions)
+    }
+    val constructExpression =
+      NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls))
+
+    new ExpressionEncoder[Any](
+      schema,
+      false,
+      extractExpressions,
+      constructExpression,
+      ClassTag.apply(cls))
+  }
+
+  /** A helper for producing encoders of Tuple2 from other encoders. */
+  def tuple[T1, T2](
+      e1: ExpressionEncoder[T1],
+      e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
+    tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]]
+}
+
+/**
+ * A generic encoder for JVM objects.
+ *
+ * @param schema The schema after converting `T` to a Spark SQL row.
+ * @param extractExpressions A set of expressions, one for each top-level field that can be used to
+ *                           extract the values from a raw object.
+ * @param clsTag A classtag for `T`.
+ */
+case class ExpressionEncoder[T](
+    schema: StructType,
+    flat: Boolean,
+    extractExpressions: Seq[Expression],
+    constructExpression: Expression,
+    clsTag: ClassTag[T])
+  extends Encoder[T] {
+
+  if (flat) require(extractExpressions.size == 1)
+
+  @transient
+  private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+  private val inputRow = new GenericMutableRow(1)
+
+  @transient
+  private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+
+  /**
+   * Returns an encoded version of `t` as a Spark SQL row.  Note that multiple calls to
+   * toRow are allowed to return the same actual [[InternalRow]] object.  Thus, the caller should
+   * copy the result before making another call if required.
+   */
+  def toRow(t: T): InternalRow = {
+    inputRow(0) = t
+    extractProjection(inputRow)
+  }
+
+  /**
+   * Returns an object of type `T`, extracting the required values from the provided row.  Note that
+   * you must `resolve` and `bind` an encoder to a specific schema before you can call this
+   * function.
+   */
+  def fromRow(row: InternalRow): T = try {
+    constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
+  } catch {
+    case e: Exception =>
+      throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e)
+  }
+
+  /**
+   * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
+   * given schema.
+   */
+  def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+    val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema))
+    val analyzedPlan = SimpleAnalyzer.execute(plan)
+    copy(constructExpression = analyzedPlan.expressions.head.children.head)
+  }
+
+  /**
+   * Returns a copy of this encoder where the expressions used to construct an object from an input
+   * row have been bound to the ordinals of the given schema.  Note that you need to first call
+   * resolve before bind.
+   */
+  def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+    copy(constructExpression = BindReferences.bindReference(constructExpression, schema))
+  }
+
+  /**
+   * Replaces any bound references in the schema with the attributes at the corresponding ordinal
+   * in the provided schema.  This can be used to "relocate" a given encoder to pull values from
+   * a different schema than it was initially bound to.  It can also be used to assign attributes
+   * to ordinal based extraction (i.e. because the input data was a tuple).
+   */
+  def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+    val positionToAttribute = AttributeMap.toIndex(schema)
+    copy(constructExpression = constructExpression transform {
+      case b: BoundReference => positionToAttribute(b.ordinal)
+    })
+  }
+
+  /**
+   * Given an encoder that has already been bound to a given schema, returns a new encoder
+   * where the positions are mapped from `oldSchema` to `newSchema`.  This can be used, for example,
+   * when you are trying to use an encoder on grouping keys that were originally part of a larger
+   * row, but now you have projected out only the key expressions.
+   */
+  def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = {
+    val positionToAttribute = AttributeMap.toIndex(oldSchema)
+    val attributeToNewPosition = AttributeMap.byIndex(newSchema)
+    copy(constructExpression = constructExpression transform {
+      case r: BoundReference =>
+        r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
+    })
+  }
+
+  /**
+   * Returns a copy of this encoder where the expressions used to create an object given an
+   * input row have been modified to pull the object out from a nested struct, instead of the
+   * top level fields.
+   */
+  def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = {
+    copy(constructExpression = constructExpression transform {
+      case u: Attribute if u != input =>
+        UnresolvedExtractValue(input, Literal(u.name))
+      case b: BoundReference if b != input =>
+        GetStructField(
+          input,
+          StructField(s"i[${b.ordinal}]", b.dataType),
+          b.ordinal)
+    })
+  }
+
+  protected val attrs = extractExpressions.flatMap(_.collect {
+    case _: UnresolvedAttribute => ""
+    case a: Attribute => s"#${a.exprId}"
+    case b: BoundReference => s"[${b.ordinal}]"
+  })
+
+  protected val schemaString =
+    schema
+      .zip(attrs)
+      .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
+
+  override def toString: String = s"class[$schemaString]"
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
deleted file mode 100644
index 34f5e6c..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{typeTag, TypeTag}
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{ObjectType, StructType}
-
-/**
- * A factory for constructing encoders that convert Scala's product type to/from the Spark SQL
- * internal binary representation.
- */
-object ProductEncoder {
-  def apply[T <: Product : TypeTag]: ClassEncoder[T] = {
-    // We convert the not-serializable TypeTag into StructType and ClassTag.
-    val mirror = typeTag[T].mirror
-    val cls = mirror.runtimeClass(typeTag[T].tpe)
-
-    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
-    val extractExpression = ScalaReflection.extractorsFor[T](inputObject)
-    val constructExpression = ScalaReflection.constructorFor[T]
-
-    new ClassEncoder[T](
-      extractExpression.dataType,
-      extractExpression.flatten,
-      constructExpression,
-      ClassTag[T](cls))
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index e9cc00a..0b42130 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -31,13 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String
  * internal binary representation.
  */
 object RowEncoder {
-  def apply(schema: StructType): ClassEncoder[Row] = {
+  def apply(schema: StructType): ExpressionEncoder[Row] = {
     val cls = classOf[Row]
     val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
     val extractExpressions = extractorsFor(inputObject, schema)
     val constructExpression = constructorFor(schema)
-    new ClassEncoder[Row](
+    new ExpressionEncoder[Row](
       schema,
+      flat = false,
       extractExpressions.asInstanceOf[CreateStruct].children,
       constructExpression,
       ClassTag(cls))

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
new file mode 100644
index 0000000..d4642a5
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+package object encoders {
+  private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
+    case e: ExpressionEncoder[A] => e
+    case _ => sys.error(s"Only expression encoders are supported today")
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
deleted file mode 100644
index a93f2d7..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.sql.types._
-
-/** An encoder for primitive Long types. */
-case class LongEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Long] {
-  private val row = UnsafeRow.createFromByteArray(64, 1)
-
-  override def clsTag: ClassTag[Long] = ClassTag.Long
-  override def schema: StructType =
-    StructType(StructField(fieldName, LongType) :: Nil)
-
-  override def fromRow(row: InternalRow): Long = row.getLong(ordinal)
-
-  override def toRow(t: Long): InternalRow = {
-    row.setLong(ordinal, t)
-    row
-  }
-
-  override def bindOrdinals(schema: Seq[Attribute]): Encoder[Long] = this
-  override def bind(schema: Seq[Attribute]): Encoder[Long] = this
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Long] = this
-}
-
-/** An encoder for primitive Integer types. */
-case class IntEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Int] {
-  private val row = UnsafeRow.createFromByteArray(64, 1)
-
-  override def clsTag: ClassTag[Int] = ClassTag.Int
-  override def schema: StructType =
-    StructType(StructField(fieldName, IntegerType) :: Nil)
-
-  override def fromRow(row: InternalRow): Int = row.getInt(ordinal)
-
-  override def toRow(t: Int): InternalRow = {
-    row.setInt(ordinal, t)
-    row
-  }
-
-  override def bindOrdinals(schema: Seq[Attribute]): Encoder[Int] = this
-  override def bind(schema: Seq[Attribute]): Encoder[Int] = this
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Int] = this
-}
-
-/** An encoder for String types. */
-case class StringEncoder(
-    fieldName: String = "value",
-    ordinal: Int = 0) extends Encoder[String] {
-
-  val record = new SpecificMutableRow(StringType :: Nil)
-
-  @transient
-  lazy val projection =
-    GenerateUnsafeProjection.generate(BoundReference(0, StringType, true) :: Nil)
-
-  override def schema: StructType =
-    StructType(
-      StructField("value", StringType, nullable = false) :: Nil)
-
-  override def clsTag: ClassTag[String] = scala.reflect.classTag[String]
-
-
-  override final def fromRow(row: InternalRow): String = {
-    row.getString(ordinal)
-  }
-
-  override final def toRow(value: String): InternalRow = {
-    val utf8String = UTF8String.fromString(value)
-    record(0) = utf8String
-    // TODO: this is a bit of a hack to produce UnsafeRows
-    projection(record)
-  }
-
-  override def bindOrdinals(schema: Seq[Attribute]): Encoder[String] = this
-  override def bind(schema: Seq[Attribute]): Encoder[String] = this
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[String] = this
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
deleted file mode 100644
index a48eeda..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
+++ /dev/null
@@ -1,173 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.types.{StructField, StructType}
-
-// Most of this file is codegen.
-// scalastyle:off
-
-/**
- * A set of composite encoders that take sub encoders and map each of their objects to a
- * Scala tuple.  Note that currently the implementation is fairly limited and only supports going
- * from an internal row to a tuple.
- */
-object TupleEncoder {
-
-  /** Code generator for composite tuple encoders. */
-  def main(args: Array[String]): Unit = {
-    (2 to 5).foreach { i =>
-      val types = (1 to i).map(t => s"T$t").mkString(", ")
-      val tupleType = s"($types)"
-      val args = (1 to i).map(t => s"e$t: Encoder[T$t]").mkString(", ")
-      val fields = (1 to i).map(t => s"""StructField("_$t", e$t.schema)""").mkString(", ")
-      val fromRow = (1 to i).map(t => s"e$t.fromRow(row)").mkString(", ")
-
-      println(
-        s"""
-          |class Tuple${i}Encoder[$types]($args) extends Encoder[$tupleType] {
-          |  val schema = StructType(Array($fields))
-          |
-          |  def clsTag: ClassTag[$tupleType] = scala.reflect.classTag[$tupleType]
-          |
-          |  def fromRow(row: InternalRow): $tupleType = {
-          |    ($fromRow)
-          |  }
-          |
-          |  override def toRow(t: $tupleType): InternalRow =
-          |    throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-          |
-          |  override def bind(schema: Seq[Attribute]): Encoder[$tupleType] = {
-          |    this
-          |  }
-          |
-          |  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[$tupleType] =
-          |    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-          |
-          |
-          |  override def bindOrdinals(schema: Seq[Attribute]): Encoder[$tupleType] =
-          |    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-          |}
-        """.stripMargin)
-    }
-  }
-}
-
-class Tuple2Encoder[T1, T2](e1: Encoder[T1], e2: Encoder[T2]) extends Encoder[(T1, T2)] {
-  val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema)))
-
-  def clsTag: ClassTag[(T1, T2)] = scala.reflect.classTag[(T1, T2)]
-
-  def fromRow(row: InternalRow): (T1, T2) = {
-    (e1.fromRow(row), e2.fromRow(row))
-  }
-
-  override def toRow(t: (T1, T2)): InternalRow =
-    throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
-  override def bind(schema: Seq[Attribute]): Encoder[(T1, T2)] = {
-    this
-  }
-
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
-  override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple3Encoder[T1, T2, T3](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3]) extends Encoder[(T1, T2, T3)] {
-  val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema)))
-
-  def clsTag: ClassTag[(T1, T2, T3)] = scala.reflect.classTag[(T1, T2, T3)]
-
-  def fromRow(row: InternalRow): (T1, T2, T3) = {
-    (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row))
-  }
-
-  override def toRow(t: (T1, T2, T3)): InternalRow =
-    throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
-  override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = {
-    this
-  }
-
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
-  override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple4Encoder[T1, T2, T3, T4](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4]) extends Encoder[(T1, T2, T3, T4)] {
-  val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema)))
-
-  def clsTag: ClassTag[(T1, T2, T3, T4)] = scala.reflect.classTag[(T1, T2, T3, T4)]
-
-  def fromRow(row: InternalRow): (T1, T2, T3, T4) = {
-    (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row))
-  }
-
-  override def toRow(t: (T1, T2, T3, T4)): InternalRow =
-    throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
-  override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = {
-    this
-  }
-
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
-  override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple5Encoder[T1, T2, T3, T4, T5](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4], e5: Encoder[T5]) extends Encoder[(T1, T2, T3, T4, T5)] {
-  val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema), StructField("_5", e5.schema)))
-
-  def clsTag: ClassTag[(T1, T2, T3, T4, T5)] = scala.reflect.classTag[(T1, T2, T3, T4, T5)]
-
-  def fromRow(row: InternalRow): (T1, T2, T3, T4, T5) = {
-    (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row), e5.fromRow(row))
-  }
-
-  override def toRow(t: (T1, T2, T3, T4, T5)): InternalRow =
-    throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
-  override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = {
-    this
-  }
-
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
-  override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 21a55a5..d2d3db0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
 import org.apache.spark.sql.catalyst.plans._
@@ -450,8 +450,8 @@ case object OneRowRelation extends LeafNode {
  */
 case class MapPartitions[T, U](
     func: Iterator[T] => Iterator[U],
-    tEncoder: Encoder[T],
-    uEncoder: Encoder[U],
+    tEncoder: ExpressionEncoder[T],
+    uEncoder: ExpressionEncoder[U],
     output: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode {
   override def missingInput: AttributeSet = AttributeSet.empty
@@ -460,8 +460,8 @@ case class MapPartitions[T, U](
 /** Factory for constructing new `AppendColumn` nodes. */
 object AppendColumn {
   def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
-    val attrs = implicitly[Encoder[U]].schema.toAttributes
-    new AppendColumn[T, U](func, implicitly[Encoder[T]], implicitly[Encoder[U]], attrs, child)
+    val attrs = encoderFor[U].schema.toAttributes
+    new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child)
   }
 }
 
@@ -472,8 +472,8 @@ object AppendColumn {
  */
 case class AppendColumn[T, U](
     func: T => U,
-    tEncoder: Encoder[T],
-    uEncoder: Encoder[U],
+    tEncoder: ExpressionEncoder[T],
+    uEncoder: ExpressionEncoder[U],
     newColumns: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output ++ newColumns
@@ -488,11 +488,11 @@ object MapGroups {
       child: LogicalPlan): MapGroups[K, T, U] = {
     new MapGroups(
       func,
-      implicitly[Encoder[K]],
-      implicitly[Encoder[T]],
-      implicitly[Encoder[U]],
+      encoderFor[K],
+      encoderFor[T],
+      encoderFor[U],
       groupingAttributes,
-      implicitly[Encoder[U]].schema.toAttributes,
+      encoderFor[U].schema.toAttributes,
       child)
   }
 }
@@ -504,9 +504,9 @@ object MapGroups {
  */
 case class MapGroups[K, T, U](
     func: (K, Iterator[T]) => Iterator[U],
-    kEncoder: Encoder[K],
-    tEncoder: Encoder[T],
-    uEncoder: Encoder[U],
+    kEncoder: ExpressionEncoder[K],
+    tEncoder: ExpressionEncoder[T],
+    uEncoder: ExpressionEncoder[U],
     groupingAttributes: Seq[Attribute],
     output: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode {

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
new file mode 100644
index 0000000..a374da4
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.runtime.universe._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.types.{StructField, ArrayType}
+
+case class RepeatedStruct(s: Seq[PrimitiveData])
+
+case class NestedArray(a: Array[Array[Int]])
+
+case class BoxedData(
+    intField: java.lang.Integer,
+    longField: java.lang.Long,
+    doubleField: java.lang.Double,
+    floatField: java.lang.Float,
+    shortField: java.lang.Short,
+    byteField: java.lang.Byte,
+    booleanField: java.lang.Boolean)
+
+case class RepeatedData(
+    arrayField: Seq[Int],
+    arrayFieldContainsNull: Seq[java.lang.Integer],
+    mapField: scala.collection.Map[Int, Long],
+    mapFieldNull: scala.collection.Map[Int, java.lang.Long],
+    structField: PrimitiveData)
+
+case class SpecificCollection(l: List[Int])
+
+class ExpressionEncoderSuite extends SparkFunSuite {
+
+  encodeDecodeTest(1)
+  encodeDecodeTest(1L)
+  encodeDecodeTest(1.toDouble)
+  encodeDecodeTest(1.toFloat)
+  encodeDecodeTest(true)
+  encodeDecodeTest(false)
+  encodeDecodeTest(1.toShort)
+  encodeDecodeTest(1.toByte)
+
+  encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
+
+  // TODO: Support creating specific subclasses of Seq.
+  ignore("Specific collection types") { encodeDecodeTest(SpecificCollection(1 :: Nil)) }
+
+  encodeDecodeTest(
+    OptionalData(
+      Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
+      Some(PrimitiveData(1, 1, 1, 1, 1, 1, true))))
+
+  encodeDecodeTest(OptionalData(None, None, None, None, None, None, None, None))
+
+  encodeDecodeTest(
+    BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
+
+  encodeDecodeTest(
+    BoxedData(null, null, null, null, null, null, null))
+
+  encodeDecodeTest(
+    RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil))
+
+  encodeDecodeTest(
+    RepeatedData(
+      Seq(1, 2),
+      Seq(new Integer(1), null, new Integer(2)),
+      Map(1 -> 2L),
+      Map(1 -> null),
+      PrimitiveData(1, 1, 1, 1, 1, 1, true)))
+
+  encodeDecodeTest(("nullable Seq[Integer]", Seq[Integer](1, null)))
+
+  encodeDecodeTest(("Seq[(String, String)]",
+    Seq(("a", "b"))))
+  encodeDecodeTest(("Seq[(Int, Int)]",
+    Seq((1, 2))))
+  encodeDecodeTest(("Seq[(Long, Long)]",
+    Seq((1L, 2L))))
+  encodeDecodeTest(("Seq[(Float, Float)]",
+    Seq((1.toFloat, 2.toFloat))))
+  encodeDecodeTest(("Seq[(Double, Double)]",
+    Seq((1.toDouble, 2.toDouble))))
+  encodeDecodeTest(("Seq[(Short, Short)]",
+    Seq((1.toShort, 2.toShort))))
+  encodeDecodeTest(("Seq[(Byte, Byte)]",
+    Seq((1.toByte, 2.toByte))))
+  encodeDecodeTest(("Seq[(Boolean, Boolean)]",
+    Seq((true, false))))
+
+  // TODO: Decoding/encoding of complex maps.
+  ignore("complex maps") {
+    encodeDecodeTest(("Map[Int, (String, String)]",
+      Map(1 ->("a", "b"))))
+  }
+
+  encodeDecodeTest(("ArrayBuffer[(String, String)]",
+    ArrayBuffer(("a", "b"))))
+  encodeDecodeTest(("ArrayBuffer[(Int, Int)]",
+    ArrayBuffer((1, 2))))
+  encodeDecodeTest(("ArrayBuffer[(Long, Long)]",
+    ArrayBuffer((1L, 2L))))
+  encodeDecodeTest(("ArrayBuffer[(Float, Float)]",
+    ArrayBuffer((1.toFloat, 2.toFloat))))
+  encodeDecodeTest(("ArrayBuffer[(Double, Double)]",
+    ArrayBuffer((1.toDouble, 2.toDouble))))
+  encodeDecodeTest(("ArrayBuffer[(Short, Short)]",
+    ArrayBuffer((1.toShort, 2.toShort))))
+  encodeDecodeTest(("ArrayBuffer[(Byte, Byte)]",
+    ArrayBuffer((1.toByte, 2.toByte))))
+  encodeDecodeTest(("ArrayBuffer[(Boolean, Boolean)]",
+    ArrayBuffer((true, false))))
+
+  encodeDecodeTest(("Seq[Seq[(Int, Int)]]",
+    Seq(Seq((1, 2)))))
+
+  encodeDecodeTestCustom(("Array[Array[(Int, Int)]]",
+    Array(Array((1, 2)))))
+  { (l, r) => l._2(0)(0) == r._2(0)(0) }
+
+  encodeDecodeTestCustom(("Array[Array[(Int, Int)]]",
+    Array(Array(Array((1, 2))))))
+  { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) }
+
+  encodeDecodeTestCustom(("Array[Array[Array[(Int, Int)]]]",
+    Array(Array(Array(Array((1, 2)))))))
+  { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) }
+
+  encodeDecodeTestCustom(("Array[Array[Array[Array[(Int, Int)]]]]",
+    Array(Array(Array(Array(Array((1, 2))))))))
+  { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) }
+
+
+  encodeDecodeTestCustom(("Array[Array[Integer]]",
+    Array(Array[Integer](1))))
+  { (l, r) => l._2(0)(0) == r._2(0)(0) }
+
+  encodeDecodeTestCustom(("Array[Array[Int]]",
+    Array(Array(1))))
+  { (l, r) => l._2(0)(0) == r._2(0)(0) }
+
+  encodeDecodeTestCustom(("Array[Array[Int]]",
+    Array(Array(Array(1)))))
+  { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) }
+
+  encodeDecodeTestCustom(("Array[Array[Array[Int]]]",
+    Array(Array(Array(Array(1))))))
+  { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) }
+
+  encodeDecodeTestCustom(("Array[Array[Array[Array[Int]]]]",
+    Array(Array(Array(Array(Array(1)))))))
+  { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) }
+
+  encodeDecodeTest(("Array[Byte] null",
+    null: Array[Byte]))
+  encodeDecodeTestCustom(("Array[Byte]",
+    Array[Byte](1, 2, 3)))
+    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
+
+  encodeDecodeTest(("Array[Int] null",
+    null: Array[Int]))
+  encodeDecodeTestCustom(("Array[Int]",
+    Array[Int](1, 2, 3)))
+    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
+
+  encodeDecodeTest(("Array[Long] null",
+    null: Array[Long]))
+  encodeDecodeTestCustom(("Array[Long]",
+    Array[Long](1, 2, 3)))
+    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
+
+  encodeDecodeTest(("Array[Double] null",
+    null: Array[Double]))
+  encodeDecodeTestCustom(("Array[Double]",
+    Array[Double](1, 2, 3)))
+    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
+
+  encodeDecodeTest(("Array[Float] null",
+    null: Array[Float]))
+  encodeDecodeTestCustom(("Array[Float]",
+    Array[Float](1, 2, 3)))
+    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
+
+  encodeDecodeTest(("Array[Boolean] null",
+    null: Array[Boolean]))
+  encodeDecodeTestCustom(("Array[Boolean]",
+    Array[Boolean](true, false)))
+    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
+
+  encodeDecodeTest(("Array[Short] null",
+    null: Array[Short]))
+  encodeDecodeTestCustom(("Array[Short]",
+    Array[Short](1, 2, 3)))
+    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
+
+  encodeDecodeTestCustom(("java.sql.Timestamp",
+    new java.sql.Timestamp(1)))
+    { (l, r) => l._2.toString == r._2.toString }
+
+  encodeDecodeTestCustom(("java.sql.Date", new java.sql.Date(1)))
+    { (l, r) => l._2.toString == r._2.toString }
+
+  /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */
+  protected def encodeDecodeTest[T : TypeTag](inputData: T) =
+    encodeDecodeTestCustom[T](inputData)((l, r) => l == r)
+
+  /**
+   * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it
+   * matches the original.
+   */
+  protected def encodeDecodeTestCustom[T : TypeTag](
+      inputData: T)(
+      c: (T, T) => Boolean) = {
+    test(s"encode/decode: $inputData - ${inputData.getClass.getName}") {
+      val encoder = try ExpressionEncoder[T]() catch {
+        case e: Exception =>
+          fail(s"Exception thrown generating encoder", e)
+      }
+      val convertedData = encoder.toRow(inputData)
+      val schema = encoder.schema.toAttributes
+      val boundEncoder = encoder.resolve(schema).bind(schema)
+      val convertedBack = try boundEncoder.fromRow(convertedData) catch {
+        case e: Exception =>
+          fail(
+           s"""Exception thrown while decoding
+              |Converted: $convertedData
+              |Schema: ${schema.mkString(",")}
+              |${encoder.schema.treeString}
+              |
+              |Encoder:
+              |$boundEncoder
+              |
+            """.stripMargin, e)
+      }
+
+      if (!c(inputData, convertedBack)) {
+        val types = convertedBack match {
+          case c: Product =>
+            c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
+          case other => other.getClass.getName
+        }
+
+        val encodedData = try {
+          convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
+            case (a: ArrayData, StructField(_, at: ArrayType, _, _)) =>
+              a.toArray[Any](at.elementType).toSeq
+            case (other, _) =>
+              other
+          }.mkString("[", ",", "]")
+        } catch {
+          case e: Throwable => s"Failed to toSeq: $e"
+        }
+
+        fail(
+          s"""Encoded/Decoded data does not match input data
+             |
+             |in:  $inputData
+             |out: $convertedBack
+             |types: $types
+             |
+             |Encoded Data: $encodedData
+             |Schema: ${schema.mkString(",")}
+             |${encoder.schema.treeString}
+             |
+             |Extract Expressions:
+             |$boundEncoder
+         """.stripMargin)
+      }
+    }
+  }
+}
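
A note for readers tracing the new binding lifecycle: the harness above exercises it end to end. `toRow` encodes the input, `resolve` replaces unresolved expressions with concrete `AttributeReference`s against the schema's attributes, and `bind` prepares the encoder to construct objects. A minimal standalone sketch of the same round trip (the tuple value is illustrative):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

val encoder = ExpressionEncoder[(Int, String)]()
val row = encoder.toRow((1, "a"))         // encode to an InternalRow
val attrs = encoder.schema.toAttributes   // concrete output attributes
val decoded = encoder
  .resolve(attrs)                         // unresolved -> AttributeReference
  .bind(attrs)                            // bind for object construction
  .fromRow(row)                           // decode back to (1, "a")
assert(decoded == ((1, "a")))
```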

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala
deleted file mode 100644
index 52f8383..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import org.apache.spark.SparkFunSuite
-
-class PrimitiveEncoderSuite extends SparkFunSuite {
-  test("long encoder") {
-    val enc = new LongEncoder()
-    val row = enc.toRow(10)
-    assert(row.getLong(0) == 10)
-    assert(enc.fromRow(row) == 10)
-  }
-
-  test("int encoder") {
-    val enc = new IntEncoder()
-    val row = enc.toRow(10)
-    assert(row.getInt(0) == 10)
-    assert(enc.fromRow(row) == 10)
-  }
-
-  test("string encoder") {
-    val enc = new StringEncoder()
-    val row = enc.toRow("test")
-    assert(row.getString(0) == "test")
-    assert(enc.fromRow(row) == "test")
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
deleted file mode 100644
index 008d0be..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
+++ /dev/null
@@ -1,282 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.collection.mutable.ArrayBuffer
-import scala.reflect.runtime.universe._
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{StructField, ArrayType}
-
-case class RepeatedStruct(s: Seq[PrimitiveData])
-
-case class NestedArray(a: Array[Array[Int]])
-
-case class BoxedData(
-    intField: java.lang.Integer,
-    longField: java.lang.Long,
-    doubleField: java.lang.Double,
-    floatField: java.lang.Float,
-    shortField: java.lang.Short,
-    byteField: java.lang.Byte,
-    booleanField: java.lang.Boolean)
-
-case class RepeatedData(
-    arrayField: Seq[Int],
-    arrayFieldContainsNull: Seq[java.lang.Integer],
-    mapField: scala.collection.Map[Int, Long],
-    mapFieldNull: scala.collection.Map[Int, java.lang.Long],
-    structField: PrimitiveData)
-
-case class SpecificCollection(l: List[Int])
-
-class ProductEncoderSuite extends SparkFunSuite {
-
-  encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
-
-  // TODO: Support creating specific subclasses of Seq.
-  ignore("Specific collection types") { encodeDecodeTest(SpecificCollection(1 :: Nil)) }
-
-  encodeDecodeTest(
-    OptionalData(
-      Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
-      Some(PrimitiveData(1, 1, 1, 1, 1, 1, true))))
-
-  encodeDecodeTest(OptionalData(None, None, None, None, None, None, None, None))
-
-  encodeDecodeTest(
-    BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
-
-  encodeDecodeTest(
-    BoxedData(null, null, null, null, null, null, null))
-
-  encodeDecodeTest(
-    RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil))
-
-  encodeDecodeTest(
-    RepeatedData(
-      Seq(1, 2),
-      Seq(new Integer(1), null, new Integer(2)),
-      Map(1 -> 2L),
-      Map(1 -> null),
-      PrimitiveData(1, 1, 1, 1, 1, 1, true)))
-
-  encodeDecodeTest(("nullable Seq[Integer]", Seq[Integer](1, null)))
-
-  encodeDecodeTest(("Seq[(String, String)]",
-    Seq(("a", "b"))))
-  encodeDecodeTest(("Seq[(Int, Int)]",
-    Seq((1, 2))))
-  encodeDecodeTest(("Seq[(Long, Long)]",
-    Seq((1L, 2L))))
-  encodeDecodeTest(("Seq[(Float, Float)]",
-    Seq((1.toFloat, 2.toFloat))))
-  encodeDecodeTest(("Seq[(Double, Double)]",
-    Seq((1.toDouble, 2.toDouble))))
-  encodeDecodeTest(("Seq[(Short, Short)]",
-    Seq((1.toShort, 2.toShort))))
-  encodeDecodeTest(("Seq[(Byte, Byte)]",
-    Seq((1.toByte, 2.toByte))))
-  encodeDecodeTest(("Seq[(Boolean, Boolean)]",
-    Seq((true, false))))
-
-  // TODO: Decoding/encoding of complex maps.
-  ignore("complex maps") {
-    encodeDecodeTest(("Map[Int, (String, String)]",
-      Map(1 ->("a", "b"))))
-  }
-
-  encodeDecodeTest(("ArrayBuffer[(String, String)]",
-    ArrayBuffer(("a", "b"))))
-  encodeDecodeTest(("ArrayBuffer[(Int, Int)]",
-    ArrayBuffer((1, 2))))
-  encodeDecodeTest(("ArrayBuffer[(Long, Long)]",
-    ArrayBuffer((1L, 2L))))
-  encodeDecodeTest(("ArrayBuffer[(Float, Float)]",
-    ArrayBuffer((1.toFloat, 2.toFloat))))
-  encodeDecodeTest(("ArrayBuffer[(Double, Double)]",
-    ArrayBuffer((1.toDouble, 2.toDouble))))
-  encodeDecodeTest(("ArrayBuffer[(Short, Short)]",
-    ArrayBuffer((1.toShort, 2.toShort))))
-  encodeDecodeTest(("ArrayBuffer[(Byte, Byte)]",
-    ArrayBuffer((1.toByte, 2.toByte))))
-  encodeDecodeTest(("ArrayBuffer[(Boolean, Boolean)]",
-    ArrayBuffer((true, false))))
-
-  encodeDecodeTest(("Seq[Seq[(Int, Int)]]",
-    Seq(Seq((1, 2)))))
-
-  encodeDecodeTestCustom(("Array[Array[(Int, Int)]]",
-    Array(Array((1, 2)))))
-  { (l, r) => l._2(0)(0) == r._2(0)(0) }
-
-  encodeDecodeTestCustom(("Array[Array[(Int, Int)]]",
-    Array(Array(Array((1, 2))))))
-  { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) }
-
-  encodeDecodeTestCustom(("Array[Array[Array[(Int, Int)]]]",
-    Array(Array(Array(Array((1, 2)))))))
-  { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) }
-
-  encodeDecodeTestCustom(("Array[Array[Array[Array[(Int, Int)]]]]",
-    Array(Array(Array(Array(Array((1, 2))))))))
-  { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) }
-
-
-  encodeDecodeTestCustom(("Array[Array[Integer]]",
-    Array(Array[Integer](1))))
-  { (l, r) => l._2(0)(0) == r._2(0)(0) }
-
-  encodeDecodeTestCustom(("Array[Array[Int]]",
-    Array(Array(1))))
-  { (l, r) => l._2(0)(0) == r._2(0)(0) }
-
-  encodeDecodeTestCustom(("Array[Array[Int]]",
-    Array(Array(Array(1)))))
-  { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) }
-
-  encodeDecodeTestCustom(("Array[Array[Array[Int]]]",
-    Array(Array(Array(Array(1))))))
-  { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) }
-
-  encodeDecodeTestCustom(("Array[Array[Array[Array[Int]]]]",
-    Array(Array(Array(Array(Array(1)))))))
-  { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) }
-
-  encodeDecodeTest(("Array[Byte] null",
-    null: Array[Byte]))
-  encodeDecodeTestCustom(("Array[Byte]",
-    Array[Byte](1, 2, 3)))
-    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
-  encodeDecodeTest(("Array[Int] null",
-    null: Array[Int]))
-  encodeDecodeTestCustom(("Array[Int]",
-    Array[Int](1, 2, 3)))
-    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
-  encodeDecodeTest(("Array[Long] null",
-    null: Array[Long]))
-  encodeDecodeTestCustom(("Array[Long]",
-    Array[Long](1, 2, 3)))
-    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
-  encodeDecodeTest(("Array[Double] null",
-    null: Array[Double]))
-  encodeDecodeTestCustom(("Array[Double]",
-    Array[Double](1, 2, 3)))
-    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
-  encodeDecodeTest(("Array[Float] null",
-    null: Array[Float]))
-  encodeDecodeTestCustom(("Array[Float]",
-    Array[Float](1, 2, 3)))
-    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
-  encodeDecodeTest(("Array[Boolean] null",
-    null: Array[Boolean]))
-  encodeDecodeTestCustom(("Array[Boolean]",
-    Array[Boolean](true, false)))
-    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
-  encodeDecodeTest(("Array[Short] null",
-    null: Array[Short]))
-  encodeDecodeTestCustom(("Array[Short]",
-    Array[Short](1, 2, 3)))
-    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
-  encodeDecodeTestCustom(("java.sql.Timestamp",
-    new java.sql.Timestamp(1)))
-    { (l, r) => l._2.toString == r._2.toString }
-
-  encodeDecodeTestCustom(("java.sql.Date", new java.sql.Date(1)))
-    { (l, r) => l._2.toString == r._2.toString }
-
-  /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */
-  protected def encodeDecodeTest[T <: Product : TypeTag](inputData: T) =
-    encodeDecodeTestCustom[T](inputData)((l, r) => l == r)
-
-  /**
-   * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it
-   * matches the original.
-   */
-  protected def encodeDecodeTestCustom[T <: Product : TypeTag](
-      inputData: T)(
-      c: (T, T) => Boolean) = {
-    test(s"encode/decode: $inputData") {
-      val encoder = try ProductEncoder[T] catch {
-        case e: Exception =>
-          fail(s"Exception thrown generating encoder", e)
-      }
-      val convertedData = encoder.toRow(inputData)
-      val schema = encoder.schema.toAttributes
-      val boundEncoder = encoder.bind(schema)
-      val convertedBack = try boundEncoder.fromRow(convertedData) catch {
-        case e: Exception =>
-          fail(
-           s"""Exception thrown while decoding
-              |Converted: $convertedData
-              |Schema: ${schema.mkString(",")}
-              |${encoder.schema.treeString}
-              |
-              |Construct Expressions:
-              |${boundEncoder.constructExpression.treeString}
-              |
-            """.stripMargin, e)
-      }
-
-      if (!c(inputData, convertedBack)) {
-        val types =
-          convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
-
-        val encodedData = try {
-          convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
-            case (a: ArrayData, StructField(_, at: ArrayType, _, _)) =>
-              a.toArray[Any](at.elementType).toSeq
-            case (other, _) =>
-              other
-          }.mkString("[", ",", "]")
-        } catch {
-          case e: Throwable => s"Failed to toSeq: $e"
-        }
-
-        fail(
-          s"""Encoded/Decoded data does not match input data
-             |
-             |in:  $inputData
-             |out: $convertedBack
-             |types: $types
-             |
-             |Encoded Data: $encodedData
-             |Schema: ${schema.mkString(",")}
-             |${encoder.schema.treeString}
-             |
-             |Extract Expressions:
-             |${boundEncoder.extractExpressions.map(_.treeString).mkString("\n")}
-             |
-             |Construct Expressions:
-             |${boundEncoder.constructExpression.treeString}
-             |
-         """.stripMargin)
-        }
-      }
-
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 32d9b0b..aa817a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -267,7 +267,7 @@ class DataFrame private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, queryExecution)
+  def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan)
 
   /**
    * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
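
This one-line change matters for the new encoder lifecycle: handing the `Dataset` constructor the `logicalPlan` (rather than the already-built `QueryExecution`) lets it resolve the target encoder against a freshly analyzed output. A hedged usage sketch (the `Person` case class is hypothetical):

```scala
import sqlContext.implicits._

case class Person(name: String, age: Int)  // hypothetical example class

val df = sqlContext.createDataFrame(Seq(Person("a", 1), Person("b", 2)))
val ds: Dataset[Person] = df.as[Person]    // encoder resolved against the analyzed output
```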

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 96213c7..e0ab5f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.types.StructType
@@ -53,15 +54,21 @@ import org.apache.spark.sql.types.StructType
  * @since 1.6.0
  */
 @Experimental
-class Dataset[T] private[sql](
+class Dataset[T] private(
     @transient val sqlContext: SQLContext,
-    @transient val queryExecution: QueryExecution)(
-    implicit val encoder: Encoder[T]) extends Serializable {
+    @transient val queryExecution: QueryExecution,
+    unresolvedEncoder: Encoder[T]) extends Serializable {
+
+  /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
+  private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match {
+    case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output)
+    case _ => throw new IllegalArgumentException("Only expression encoders are currently supported")
+  }
 
   private implicit def classTag = encoder.clsTag
 
   private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
-    this(sqlContext, new QueryExecution(sqlContext, plan))
+    this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
 
   /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */
   def schema: StructType = encoder.schema
@@ -76,7 +83,9 @@ class Dataset[T] private[sql](
    * TODO: document binding rules
    * @since 1.6.0
    */
-  def as[U : Encoder]: Dataset[U] = new Dataset(sqlContext, queryExecution)(implicitly[Encoder[U]])
+  def as[U : Encoder]: Dataset[U] = {
+    new Dataset(sqlContext, queryExecution, encoderFor[U])
+  }
 
   /**
    * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
@@ -103,7 +112,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def rdd: RDD[T] = {
-    val tEnc = implicitly[Encoder[T]]
+    val tEnc = encoderFor[T]
     val input = queryExecution.analyzed.output
     queryExecution.toRdd.mapPartitions { iter =>
       val bound = tEnc.bind(input)
@@ -150,9 +159,9 @@ class Dataset[T] private[sql](
       sqlContext,
       MapPartitions[T, U](
         func,
-        implicitly[Encoder[T]],
-        implicitly[Encoder[U]],
-        implicitly[Encoder[U]].schema.toAttributes,
+        encoderFor[T],
+        encoderFor[U],
+        encoderFor[U].schema.toAttributes,
         logicalPlan))
   }
 
@@ -209,8 +218,8 @@ class Dataset[T] private[sql](
     val executed = sqlContext.executePlan(withGroupingKey)
 
     new GroupedDataset(
-      implicitly[Encoder[K]].bindOrdinals(withGroupingKey.newColumns),
-      implicitly[Encoder[T]].bind(inputPlan.output),
+      encoderFor[K].resolve(withGroupingKey.newColumns),
+      encoderFor[T].bind(inputPlan.output),
       executed,
       inputPlan.output,
       withGroupingKey.newColumns)
@@ -221,6 +230,18 @@ class Dataset[T] private[sql](
    * ****************** */
 
   /**
+   * Selects a set of column-based expressions.
+   * {{{
+   *   df.select($"colA", $"colB" + 1)
+   * }}}
+   * @group dfops
+   * @since 1.6.0
+   */
+  // Copied from DataFrame to make sure we don't have invalid overloads.
+  @scala.annotation.varargs
+  def select(cols: Column*): DataFrame = toDF().select(cols: _*)
+
+  /**
    * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
    *
    * {{{
@@ -233,88 +254,64 @@ class Dataset[T] private[sql](
     new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
   }
 
-  // Codegen
-  // scalastyle:off
-
-  /** sbt scalaShell; println(Seq(1).toDS().genSelect) */
-  private def genSelect: String = {
-    (2 to 5).map { n =>
-      val types = (1 to n).map(i =>s"U$i").mkString(", ")
-      val args = (1 to n).map(i => s"c$i: TypedColumn[U$i]").mkString(", ")
-      val encoders = (1 to n).map(i => s"c$i.encoder").mkString(", ")
-      val schema = (1 to n).map(i => s"""Alias(c$i.expr, "_$i")()""").mkString(" :: ")
-      s"""
-         |/**
-         | * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
-         | * @since 1.6.0
-         | */
-         |def select[$types]($args): Dataset[($types)] = {
-         |  implicit val te = new Tuple${n}Encoder($encoders)
-         |  new Dataset[($types)](sqlContext,
-         |    Project(
-         |      $schema :: Nil,
-         |      logicalPlan))
-         |}
-         |
-       """.stripMargin
-    }.mkString("\n")
+  /**
+   * Internal helper function for building typed selects that return tuples. For simplicity and
+   * code reuse, we do this without the help of the type system and then use helper functions
+   * that cast appropriately for the user-facing interface.
+   */
+  protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = {
+    val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
+    val unresolvedPlan = Project(aliases, logicalPlan)
+    val execution = new QueryExecution(sqlContext, unresolvedPlan)
+    // Rebind the encoders to the nested schema that will be produced by the select.
+    val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
+      case (e: ExpressionEncoder[_], a) if !e.flat =>
+        e.nested(a.toAttribute).resolve(execution.analyzed.output)
+      case (e, a) =>
+        e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output)
+    }
+    new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
   }
 
   /**
    * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
    * @since 1.6.0
    */
-  def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = {
-    implicit val te = new Tuple2Encoder(c1.encoder, c2.encoder)
-    new Dataset[(U1, U2)](sqlContext,
-      Project(
-        Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Nil,
-        logicalPlan))
-  }
-
-
+  def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] =
+    selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
 
   /**
    * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
    * @since 1.6.0
    */
-  def select[U1, U2, U3](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = {
-    implicit val te = new Tuple3Encoder(c1.encoder, c2.encoder, c3.encoder)
-    new Dataset[(U1, U2, U3)](sqlContext,
-      Project(
-        Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Nil,
-        logicalPlan))
-  }
-
-
+  def select[U1, U2, U3](
+      c1: TypedColumn[U1],
+      c2: TypedColumn[U2],
+      c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] =
+    selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
 
   /**
    * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
    * @since 1.6.0
    */
-  def select[U1, U2, U3, U4](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = {
-    implicit val te = new Tuple4Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder)
-    new Dataset[(U1, U2, U3, U4)](sqlContext,
-      Project(
-        Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Nil,
-        logicalPlan))
-  }
-
-
+  def select[U1, U2, U3, U4](
+      c1: TypedColumn[U1],
+      c2: TypedColumn[U2],
+      c3: TypedColumn[U3],
+      c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] =
+    selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
 
   /**
    * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
    * @since 1.6.0
    */
-  def select[U1, U2, U3, U4, U5](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4], c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = {
-    implicit val te = new Tuple5Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder, c5.encoder)
-    new Dataset[(U1, U2, U3, U4, U5)](sqlContext,
-      Project(
-        Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Alias(c5.expr, "_5")() :: Nil,
-        logicalPlan))
-  }
-
-  // scalastyle:on
+  def select[U1, U2, U3, U4, U5](
+      c1: TypedColumn[U1],
+      c2: TypedColumn[U2],
+      c3: TypedColumn[U3],
+      c4: TypedColumn[U4],
+      c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] =
+    selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
 
   /* **************** *
    *  Set operations  *
@@ -360,6 +357,48 @@ class Dataset[T] private[sql](
    */
   def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except)
 
+  /* ****** *
+   *  Joins *
+   * ****** */
+
+  /**
+   * Joins this [[Dataset]] with another [[Dataset]], returning a [[Tuple2]] for each pair
+   * where `condition` evaluates to true.
+   *
+   * This is similar to the relational `join` function with one important difference in the
+   * result schema. Since `joinWith` preserves objects present on either side of the join, the
+   * result schema is similarly nested into a tuple under the column names `_1` and `_2`.
+   *
+   * This type of join can be useful both for preserving type-safety with the original object
+   * types and for working with relational data where either side of the join has column
+   * names in common.
+   *
+   * @since 1.6.0
+   */
+  def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
+    val left = this.logicalPlan
+    val right = other.logicalPlan
+
+    val leftData = this.encoder match {
+      case e if e.flat => Alias(left.output.head, "_1")()
+      case _ => Alias(CreateStruct(left.output), "_1")()
+    }
+    val rightData = other.encoder match {
+      case e if e.flat => Alias(right.output.head, "_2")()
+      case _ => Alias(CreateStruct(right.output), "_2")()
+    }
+    val leftEncoder =
+      if (encoder.flat) encoder else encoder.nested(leftData.toAttribute)
+    val rightEncoder =
+      if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute)
+    implicit val tuple2Encoder: Encoder[(T, U)] =
+      ExpressionEncoder.tuple(leftEncoder, rightEncoder)
+
+    withPlan[(T, U)](other) { (left, right) =>
+      Project(
+        leftData :: rightData :: Nil,
+        Join(left, right, Inner, Some(condition.expr)))
+    }
+  }
+
   /* ************************** *
    *  Gather to Driver Actions  *
    * ************************** */
@@ -380,13 +419,10 @@ class Dataset[T] private[sql](
   private[sql] def logicalPlan = queryExecution.analyzed
 
   private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
-    new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)))
+    new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder)
 
   private[sql] def withPlan[R : Encoder](
       other: Dataset[_])(
       f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
-    new Dataset[R](
-      sqlContext,
-      sqlContext.executePlan(
-        f(logicalPlan, other.logicalPlan)))
+    new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan))
 }
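
Beyond the case-class example in the commit message, `joinWith` also has a flat-encoder branch (above), where each side's single column is aliased directly instead of being wrapped in `CreateStruct`. A sketch under the assumptions that flat encoders name their single column `value` and that `as(alias)` disambiguates the two sides:

```scala
import sqlContext.implicits._

val left = Seq(1, 2, 3).toDS().as("a")    // flat Int encoder, column "value"
val right = Seq(1, 3, 5).toDS().as("b")

val pairs: Dataset[(Int, Int)] =
  left.joinWith(right, $"a.value" === $"b.value")
// expected pairs: (1, 1) and (3, 3)
```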

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 5e7198f..2cb9443 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.SQLConf.SQLConfEntry
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
 import org.apache.spark.sql.catalyst.errors.DialectException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
@@ -491,7 +491,7 @@ class SQLContext private[sql](
 
 
   def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
-    val enc = implicitly[Encoder[T]]
+    val enc = encoderFor[T]
     val attributes = enc.schema.toAttributes
     val encoded = data.map(d => enc.toRow(d).copy())
     val plan = new LocalRelation(attributes, encoded)
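
Note that `createDataset` encodes eagerly: every element is converted with `enc.toRow(d).copy()` before the `LocalRelation` is constructed, so decoding back into objects only happens on demand. For example:

```scala
import sqlContext.implicits._

val ds = sqlContext.createDataset(Seq("x", "y", "z"))
// the strings are already stored as encoded rows at this point;
// collect() decodes them back into JVM objects
ds.collect()  // Array("x", "y", "z")
```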

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index af8474d..f460a86 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -37,11 +37,16 @@ import org.apache.spark.unsafe.types.UTF8String
 abstract class SQLImplicits {
   protected def _sqlContext: SQLContext
 
-  implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T]
+  implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]()
 
-  implicit def newIntEncoder: Encoder[Int] = new IntEncoder()
-  implicit def newLongEncoder: Encoder[Long] = new LongEncoder()
-  implicit def newStringEncoder: Encoder[String] = new StringEncoder()
+  implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true)
+  implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
+  implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true)
+  implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true)
+  implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true)
+  implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true)
+  implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true)
+  implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true)
 
   implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = {
     DatasetHolder(_sqlContext.createDataset(s))
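
With these additional implicits, `toDS()` now resolves for all of the primitive element types listed above rather than just Int, Long, and String. A quick sketch:

```scala
import sqlContext.implicits._

val doubles = Seq(1.5, 2.5).toDS()    // via newDoubleEncoder
val flags = Seq(true, false).toDS()   // via newBooleanEncoder
```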

http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 2bb3dba..8993847 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -319,8 +319,8 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl
  */
 case class MapPartitions[T, U](
     func: Iterator[T] => Iterator[U],
-    tEncoder: Encoder[T],
-    uEncoder: Encoder[U],
+    tEncoder: ExpressionEncoder[T],
+    uEncoder: ExpressionEncoder[U],
     output: Seq[Attribute],
     child: SparkPlan) extends UnaryNode {
 
@@ -337,8 +337,8 @@ case class MapPartitions[T, U](
  */
 case class AppendColumns[T, U](
     func: T => U,
-    tEncoder: Encoder[T],
-    uEncoder: Encoder[U],
+    tEncoder: ExpressionEncoder[T],
+    uEncoder: ExpressionEncoder[U],
     newColumns: Seq[Attribute],
     child: SparkPlan) extends UnaryNode {
 
@@ -363,9 +363,9 @@ case class AppendColumns[T, U](
  */
 case class MapGroups[K, T, U](
     func: (K, Iterator[T]) => Iterator[U],
-    kEncoder: Encoder[K],
-    tEncoder: Encoder[T],
-    uEncoder: Encoder[U],
+    kEncoder: ExpressionEncoder[K],
+    tEncoder: ExpressionEncoder[T],
+    uEncoder: ExpressionEncoder[U],
     groupingAttributes: Seq[Attribute],
     output: Seq[Attribute],
     child: SparkPlan) extends UnaryNode {
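
These physical operators now require `ExpressionEncoder`s so they can resolve and bind against their child's output. They are produced by planning the typed `Dataset` operations; for instance, a hedged sketch of a call that lowers to the `MapPartitions` node above:

```scala
import sqlContext.implicits._  // assumed in scope for the encoders

val ds = Seq(1, 2, 3).toDS()
// transforms each partition's iterator of decoded objects
val labels = ds.mapPartitions(iter => iter.map(i => s"item-$i"))
```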

