You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2015/10/14 02:09:24 UTC

spark git commit: [SPARK-11090] [SQL] Constructor for Product types from InternalRow

Repository: spark
Updated Branches:
  refs/heads/master 3889b1c7a -> 328d1b3e4


[SPARK-11090] [SQL] Constructor for Product types from InternalRow

This is a first draft of the ability to construct expressions that will take a Catalyst internal row and construct a Product (case class or tuple) that has fields with the correct names.  Supported features include:
 - Nested classes
 - Maps
 - Efficient handling of arrays of primitive types

Not yet supported:
 - Case classes that require custom collection types (e.g. List instead of Seq).

Author: Michael Armbrust <mi...@databricks.com>

Closes #9100 from marmbrus/productContructor.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/328d1b3e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/328d1b3e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/328d1b3e

Branch: refs/heads/master
Commit: 328d1b3e4bc39cce653342e04f9e08af12dd7ed8
Parents: 3889b1c7
Author: Michael Armbrust <mi...@databricks.com>
Authored: Tue Oct 13 17:09:17 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Oct 13 17:09:17 2015 -0700

----------------------------------------------------------------------
 .../catalyst/expressions/UnsafeArrayData.java   |   4 +
 .../spark/sql/catalyst/ScalaReflection.scala    | 302 ++++++++++++++-
 .../spark/sql/catalyst/encoders/Encoder.scala   |  14 +
 .../sql/catalyst/encoders/ProductEncoder.scala  |  26 +-
 .../sql/catalyst/expressions/objects.scala      | 154 +++++++-
 .../spark/sql/types/ArrayBasedMapData.scala     |   4 +
 .../org/apache/spark/sql/types/ArrayData.scala  |   5 +
 .../spark/sql/types/GenericArrayData.scala      |   4 +-
 .../catalyst/encoders/ProductEncoderSuite.scala | 369 ++++++++++++-------
 9 files changed, 723 insertions(+), 159 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/328d1b3e/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 796f8ab..4c63abb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -74,6 +74,10 @@ public class UnsafeArrayData extends ArrayData {
     assert ordinal < numElements : "ordinal (" + ordinal + ") should < " + numElements;
   }
 
+  public Object[] array() {
+    throw new UnsupportedOperationException("Only supported on GenericArrayData.");
+  }
+
   /**
    * Construct a new UnsafeArrayData. The resulting UnsafeArrayData won't be usable until
    * `pointTo()` has been called, since the value returned by this constructor is equivalent

http://git-wip-us.apache.org/repos/asf/spark/blob/328d1b3e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8b733f2..8edd649 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst
 
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
@@ -80,6 +81,9 @@ trait ScalaReflection {
    * Returns the Spark SQL DataType for a given scala type.  Where this is not an exact mapping
    * to a native type, an ObjectType is returned. Special handling is also used for Arrays including
    * those that hold primitive types.
+   *
+   * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type
+   * system.  As a result, ObjectType will be returned for things like boxed Integers.
    */
   def dataTypeFor(tpe: `Type`): DataType = tpe match {
     case t if t <:< definitions.IntTpe => IntegerType
@@ -114,6 +118,298 @@ trait ScalaReflection {
       }
   }
 
+  /**
+   * Given a type `T` this function constructs an ObjectType that holds a class of type
+   * Array[T].  Special handling is performed for primitive types to map them back to their raw
+   * JVM form instead of the Scala Array that handles auto boxing.
+   */
+  def arrayClassFor(tpe: `Type`): DataType = {
+    val cls = tpe match {
+      case t if t <:< definitions.IntTpe => classOf[Array[Int]]
+      case t if t <:< definitions.LongTpe => classOf[Array[Long]]
+      case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
+      case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
+      case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
+      case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
+      case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
+      case other =>
+        // There is probably a better way to do this, but I couldn't find it...
+        val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
+        java.lang.reflect.Array.newInstance(elementType, 1).getClass
+
+    }
+    ObjectType(cls)
+  }
+
+  /**
+   * Returns an expression that can be used to construct an object of type `T` given an input
+   * row with a compatible schema.  Fields of the row will be extracted using UnresolvedAttributes
+   * of the same name as the constructor arguments.  Nested classes will have their fields accessed
+   * using UnresolvedExtractValue.
+   */
+  def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None)
+
+  protected def constructorFor(
+      tpe: `Type`,
+      path: Option[Expression]): Expression = ScalaReflectionLock.synchronized {
+
+    /** Returns the current path with a sub-field extracted. */
+    def addToPath(part: String) =
+      path
+        .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
+        .getOrElse(UnresolvedAttribute(part))
+
+    /** Returns the current path or throws an error. */
+    def getPath = path.getOrElse(sys.error("Constructors must start at a class type"))
+
+    tpe match {
+      case t if !dataTypeFor(t).isInstanceOf[ObjectType] =>
+        getPath
+
+      case t if t <:< localTypeOf[Option[_]] =>
+        val TypeRef(_, _, Seq(optType)) = t
+        val boxedType = optType match {
+          // For primitive types we must manually box the primitive value.
+          case t if t <:< definitions.IntTpe => Some(classOf[java.lang.Integer])
+          case t if t <:< definitions.LongTpe => Some(classOf[java.lang.Long])
+          case t if t <:< definitions.DoubleTpe => Some(classOf[java.lang.Double])
+          case t if t <:< definitions.FloatTpe => Some(classOf[java.lang.Float])
+          case t if t <:< definitions.ShortTpe => Some(classOf[java.lang.Short])
+          case t if t <:< definitions.ByteTpe => Some(classOf[java.lang.Byte])
+          case t if t <:< definitions.BooleanTpe => Some(classOf[java.lang.Boolean])
+          case _ => None
+        }
+
+        boxedType.map { boxedType =>
+          val objectType = ObjectType(boxedType)
+          WrapOption(
+            objectType,
+            NewInstance(
+              boxedType,
+              getPath :: Nil,
+              propagateNull = true,
+              objectType))
+        }.getOrElse {
+          val className: String = optType.erasure.typeSymbol.asClass.fullName
+          val cls = Utils.classForName(className)
+          val objectType = ObjectType(cls)
+
+          WrapOption(objectType, constructorFor(optType, path))
+        }
+
+      case t if t <:< localTypeOf[java.lang.Integer] =>
+        val boxedType = classOf[java.lang.Integer]
+        val objectType = ObjectType(boxedType)
+        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+      case t if t <:< localTypeOf[java.lang.Long] =>
+        val boxedType = classOf[java.lang.Long]
+        val objectType = ObjectType(boxedType)
+        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+      case t if t <:< localTypeOf[java.lang.Double] =>
+        val boxedType = classOf[java.lang.Double]
+        val objectType = ObjectType(boxedType)
+        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+      case t if t <:< localTypeOf[java.lang.Float] =>
+        val boxedType = classOf[java.lang.Float]
+        val objectType = ObjectType(boxedType)
+        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+      case t if t <:< localTypeOf[java.lang.Short] =>
+        val boxedType = classOf[java.lang.Short]
+        val objectType = ObjectType(boxedType)
+        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+      case t if t <:< localTypeOf[java.lang.Byte] =>
+        val boxedType = classOf[java.lang.Byte]
+        val objectType = ObjectType(boxedType)
+        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+      case t if t <:< localTypeOf[java.lang.Boolean] =>
+        val boxedType = classOf[java.lang.Boolean]
+        val objectType = ObjectType(boxedType)
+        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+      case t if t <:< localTypeOf[java.sql.Date] =>
+        StaticInvoke(
+          DateTimeUtils,
+          ObjectType(classOf[java.sql.Date]),
+          "toJavaDate",
+          getPath :: Nil,
+          propagateNull = true)
+
+      case t if t <:< localTypeOf[java.sql.Timestamp] =>
+        StaticInvoke(
+          DateTimeUtils,
+          ObjectType(classOf[java.sql.Timestamp]),
+          "toJavaTimestamp",
+          getPath :: Nil,
+          propagateNull = true)
+
+      case t if t <:< localTypeOf[java.lang.String] =>
+        Invoke(getPath, "toString", ObjectType(classOf[String]))
+
+      case t if t <:< localTypeOf[java.math.BigDecimal] =>
+        Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+
+      case t if t <:< localTypeOf[Array[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+        val elementDataType = dataTypeFor(elementType)
+        val Schema(dataType, nullable) = schemaFor(elementType)
+
+        val primitiveMethod = elementType match {
+          case t if t <:< definitions.IntTpe => Some("toIntArray")
+          case t if t <:< definitions.LongTpe => Some("toLongArray")
+          case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
+          case t if t <:< definitions.FloatTpe => Some("toFloatArray")
+          case t if t <:< definitions.ShortTpe => Some("toShortArray")
+          case t if t <:< definitions.ByteTpe => Some("toByteArray")
+          case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
+          case _ => None
+        }
+
+        primitiveMethod.map { method =>
+          Invoke(getPath, method, dataTypeFor(t))
+        }.getOrElse {
+          val returnType = dataTypeFor(t)
+          Invoke(
+            MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType),
+            "array",
+            returnType)
+        }
+
+      case t if t <:< localTypeOf[Map[_, _]] =>
+        val TypeRef(_, _, Seq(keyType, valueType)) = t
+        val Schema(keyDataType, _) = schemaFor(keyType)
+        val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+
+        val primitiveMethodKey = keyType match {
+          case t if t <:< definitions.IntTpe => Some("toIntArray")
+          case t if t <:< definitions.LongTpe => Some("toLongArray")
+          case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
+          case t if t <:< definitions.FloatTpe => Some("toFloatArray")
+          case t if t <:< definitions.ShortTpe => Some("toShortArray")
+          case t if t <:< definitions.ByteTpe => Some("toByteArray")
+          case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
+          case _ => None
+        }
+
+        val keyData =
+          Invoke(
+            MapObjects(
+              p => constructorFor(keyType, Some(p)),
+              Invoke(getPath, "keyArray", ArrayType(keyDataType)),
+              keyDataType),
+            "array",
+            ObjectType(classOf[Array[Any]]))
+
+        val primitiveMethodValue = valueType match {
+          case t if t <:< definitions.IntTpe => Some("toIntArray")
+          case t if t <:< definitions.LongTpe => Some("toLongArray")
+          case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
+          case t if t <:< definitions.FloatTpe => Some("toFloatArray")
+          case t if t <:< definitions.ShortTpe => Some("toShortArray")
+          case t if t <:< definitions.ByteTpe => Some("toByteArray")
+          case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
+          case _ => None
+        }
+
+        val valueData =
+          Invoke(
+            MapObjects(
+              p => constructorFor(valueType, Some(p)),
+              Invoke(getPath, "valueArray", ArrayType(valueDataType)),
+              valueDataType),
+            "array",
+            ObjectType(classOf[Array[Any]]))
+
+        StaticInvoke(
+          ArrayBasedMapData,
+          ObjectType(classOf[Map[_, _]]),
+          "toScalaMap",
+          keyData :: valueData :: Nil)
+
+      case t if t <:< localTypeOf[Seq[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+        val elementDataType = dataTypeFor(elementType)
+        val Schema(dataType, nullable) = schemaFor(elementType)
+
+        // Avoid boxing when possible by just wrapping a primitive array.
+        val primitiveMethod = elementType match {
+          case _ if nullable => None
+          case t if t <:< definitions.IntTpe => Some("toIntArray")
+          case t if t <:< definitions.LongTpe => Some("toLongArray")
+          case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
+          case t if t <:< definitions.FloatTpe => Some("toFloatArray")
+          case t if t <:< definitions.ShortTpe => Some("toShortArray")
+          case t if t <:< definitions.ByteTpe => Some("toByteArray")
+          case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
+          case _ => None
+        }
+
+        val arrayData = primitiveMethod.map { method =>
+          Invoke(getPath, method, arrayClassFor(elementType))
+        }.getOrElse {
+          Invoke(
+            MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType),
+            "array",
+            arrayClassFor(elementType))
+        }
+
+        StaticInvoke(
+          scala.collection.mutable.WrappedArray,
+          ObjectType(classOf[Seq[_]]),
+          "make",
+          arrayData :: Nil)
+
+
+      case t if t <:< localTypeOf[Product] =>
+        val formalTypeArgs = t.typeSymbol.asClass.typeParams
+        val TypeRef(_, _, actualTypeArgs) = t
+        val constructorSymbol = t.member(nme.CONSTRUCTOR)
+        val params = if (constructorSymbol.isMethod) {
+          constructorSymbol.asMethod.paramss
+        } else {
+          // Find the primary constructor, and use its parameter ordering.
+          val primaryConstructorSymbol: Option[Symbol] =
+            constructorSymbol.asTerm.alternatives.find(s =>
+              s.isMethod && s.asMethod.isPrimaryConstructor)
+
+          if (primaryConstructorSymbol.isEmpty) {
+            sys.error("Internal SQL error: Product object did not have a primary constructor.")
+          } else {
+            primaryConstructorSymbol.get.asMethod.paramss
+          }
+        }
+
+        val className: String = t.erasure.typeSymbol.asClass.fullName
+        val cls = Utils.classForName(className)
+
+        val arguments = params.head.map { p =>
+          val fieldName = p.name.toString
+          val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+          val dataType = dataTypeFor(fieldType)
+
+          constructorFor(fieldType, Some(addToPath(fieldName)))
+        }
+
+        val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
+
+        if (path.nonEmpty) {
+          expressions.If(
+            IsNull(getPath),
+            expressions.Literal.create(null, ObjectType(cls)),
+            newInstance
+          )
+        } else {
+          newInstance
+        }
+
+    }
+  }
+
   /** Returns expressions for extracting all the fields from the given type. */
   def extractorsFor[T : TypeTag](inputObject: Expression): Seq[Expression] = {
     ScalaReflectionLock.synchronized {
@@ -227,13 +523,13 @@ trait ScalaReflection {
           val elementDataType = dataTypeFor(elementType)
           val Schema(dataType, nullable) = schemaFor(elementType)
 
-          if (!elementDataType.isInstanceOf[AtomicType]) {
-            MapObjects(extractorFor(_, elementType), inputObject, elementDataType)
-          } else {
+          if (dataType.isInstanceOf[AtomicType]) {
             NewInstance(
               classOf[GenericArrayData],
               inputObject :: Nil,
               dataType = ArrayType(dataType, nullable))
+          } else {
+            MapObjects(extractorFor(_, elementType), inputObject, elementDataType)
           }
 
         case t if t <:< localTypeOf[Map[_, _]] =>

http://git-wip-us.apache.org/repos/asf/spark/blob/328d1b3e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index 8dacfa9..3618247 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
+
 import scala.reflect.ClassTag
 
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types.StructType
 
@@ -41,4 +43,16 @@ trait Encoder[T] {
    * copy the result before making another call if required.
    */
   def toRow(t: T): InternalRow
+
+  /**
+   * Returns an object of type `T`, extracting the required values from the provided row.  Note that
+   * you must `bind` an encoder to a specific schema before you can call this function.
+   */
+  def fromRow(row: InternalRow): T
+
+  /**
+   * Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the
+   * given schema.
+   */
+  def bind(schema: Seq[Attribute]): Encoder[T]
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/328d1b3e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
index a236136..b038188 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
+import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
@@ -31,7 +33,7 @@ import org.apache.spark.sql.types.{ObjectType, StructType}
  * internal binary representation.
  */
 object ProductEncoder {
-  def apply[T <: Product : TypeTag]: Encoder[T] = {
+  def apply[T <: Product : TypeTag]: ClassEncoder[T] = {
     // We convert the not-serializable TypeTag into StructType and ClassTag.
     val schema = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType]
     val mirror = typeTag[T].mirror
@@ -39,7 +41,8 @@ object ProductEncoder {
 
     val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
     val extractExpressions = ScalaReflection.extractorsFor[T](inputObject)
-    new ClassEncoder[T](schema, extractExpressions, ClassTag[T](cls))
+    val constructExpression = ScalaReflection.constructorFor[T]
+    new ClassEncoder[T](schema, extractExpressions, constructExpression, ClassTag[T](cls))
   }
 }
 
@@ -54,14 +57,31 @@ object ProductEncoder {
 case class ClassEncoder[T](
     schema: StructType,
     extractExpressions: Seq[Expression],
+    constructExpression: Expression,
     clsTag: ClassTag[T])
   extends Encoder[T] {
 
   private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
   private val inputRow = new GenericMutableRow(1)
 
+  private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+  private val dataType = ObjectType(clsTag.runtimeClass)
+
   override def toRow(t: T): InternalRow = {
     inputRow(0) = t
     extractProjection(inputRow)
   }
+
+  override def fromRow(row: InternalRow): T = {
+    constructProjection(row).get(0, dataType).asInstanceOf[T]
+  }
+
+  override def bind(schema: Seq[Attribute]): ClassEncoder[T] = {
+    val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema))
+    val analyzedPlan = SimpleAnalyzer.execute(plan)
+    val resolvedExpression = analyzedPlan.expressions.head.children.head
+    val boundExpression = BindReferences.bindReference(resolvedExpression, schema)
+
+    copy(constructExpression = boundExpression)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/328d1b3e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index e1f960a..e8c1c93 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -17,9 +17,12 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
+
 import scala.language.existentials
 
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.types._
 
@@ -48,7 +51,7 @@ case class StaticInvoke(
     case other => other.getClass.getName.stripSuffix("$")
   }
   override def nullable: Boolean = true
-  override def children: Seq[Expression] = Nil
+  override def children: Seq[Expression] = arguments
 
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
@@ -69,7 +72,7 @@ case class StaticInvoke(
       s"""
         ${argGen.map(_.code).mkString("\n")}
 
-        boolean ${ev.isNull} = true;
+        boolean ${ev.isNull} = !$argsNonNull;
         $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
 
         if ($argsNonNull) {
@@ -81,8 +84,8 @@ case class StaticInvoke(
       s"""
         ${argGen.map(_.code).mkString("\n")}
 
-        final boolean ${ev.isNull} = ${ev.value} == null;
         $javaType ${ev.value} = $objectName.$functionName($argString);
+        final boolean ${ev.isNull} = ${ev.value} == null;
       """
     }
   }
@@ -92,6 +95,10 @@ case class StaticInvoke(
  * Calls the specified function on an object, optionally passing arguments.  If the `targetObject`
  * expression evaluates to null then null will be returned.
  *
+ * In some cases, due to erasure, the schema may expect a primitive type when in fact the method
+ * is returning java.lang.Object.  In this case, we will generate code that attempts to unbox the
+ * value automatically.
+ *
  * @param targetObject An expression that will return the object to call the method on.
  * @param functionName The name of the method to call.
  * @param dataType The expected return type of the function.
@@ -109,6 +116,35 @@ case class Invoke(
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
 
+  lazy val method = targetObject.dataType match {
+    case ObjectType(cls) =>
+      cls
+        .getMethods
+        .find(_.getName == functionName)
+        .getOrElse(sys.error(s"Couldn't find $functionName on $cls"))
+        .getReturnType
+        .getName
+    case _ => ""
+  }
+
+  lazy val unboxer = (dataType, method) match {
+    case (IntegerType, "java.lang.Object") => (s: String) =>
+      s"((java.lang.Integer)$s).intValue()"
+    case (LongType, "java.lang.Object") => (s: String) =>
+      s"((java.lang.Long)$s).longValue()"
+    case (FloatType, "java.lang.Object") => (s: String) =>
+      s"((java.lang.Float)$s).floatValue()"
+    case (ShortType, "java.lang.Object") => (s: String) =>
+      s"((java.lang.Short)$s).shortValue()"
+    case (ByteType, "java.lang.Object") => (s: String) =>
+      s"((java.lang.Byte)$s).byteValue()"
+    case (DoubleType, "java.lang.Object") => (s: String) =>
+      s"((java.lang.Double)$s).doubleValue()"
+    case (BooleanType, "java.lang.Object") => (s: String) =>
+      s"((java.lang.Boolean)$s).booleanValue()"
+    case _ => identity[String] _
+  }
+
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val javaType = ctx.javaType(dataType)
     val obj = targetObject.gen(ctx)
@@ -123,6 +159,8 @@ case class Invoke(
       ""
     }
 
+    val value = unboxer(s"${obj.value}.$functionName($argString)")
+
     s"""
       ${obj.code}
       ${argGen.map(_.code).mkString("\n")}
@@ -130,7 +168,7 @@ case class Invoke(
       boolean ${ev.isNull} = ${obj.value} == null;
       $javaType ${ev.value} =
         ${ev.isNull} ?
-        ${ctx.defaultValue(dataType)} : ($javaType) ${obj.value}.$functionName($argString);
+        ${ctx.defaultValue(dataType)} : ($javaType) $value;
       $objNullCheck
     """
   }
@@ -190,8 +228,8 @@ case class NewInstance(
       s"""
         ${argGen.map(_.code).mkString("\n")}
 
-        final boolean ${ev.isNull} = ${ev.value} == null;
         $javaType ${ev.value} = new $className($argString);
+        final boolean ${ev.isNull} = ${ev.value} == null;
       """
     }
   }
@@ -210,8 +248,6 @@ case class UnwrapOption(
 
   override def nullable: Boolean = true
 
-  override def children: Seq[Expression] = Nil
-
   override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil
 
   override def eval(input: InternalRow): Any =
@@ -231,6 +267,43 @@ case class UnwrapOption(
   }
 }
 
+/**
+ * Converts the result of evaluating `child` into an option, returning None when the child's
+ * isNull bit is set and wrapping the value in a `Some` otherwise.
+ * @param optionType The datatype to be held inside of the Option.
+ * @param child The expression to evaluate and wrap.
+ */
+case class WrapOption(optionType: DataType, child: Expression)
+  extends UnaryExpression with ExpectsInputTypes {
+
+  override def dataType: DataType = ObjectType(classOf[Option[_]])
+
+  override def nullable: Boolean = true
+
+  override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val javaType = ctx.javaType(optionType)
+    val inputObject = child.gen(ctx)
+
+    s"""
+      ${inputObject.code}
+
+      boolean ${ev.isNull} = false;
+      scala.Option<$javaType> ${ev.value} =
+        ${inputObject.isNull} ?
+        scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
+    """
+  }
+}
+
+/**
+ * A placeholder for the loop variable used in [[MapObjects]].  This should never be constructed
+ * manually, but will instead be passed into the provided lambda function.
+ */
 case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends Expression {
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
@@ -251,7 +324,7 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext
  * as an ArrayType.  This is similar to a typical map operation, but where the lambda function
  * is expressed using catalyst expressions.
  *
- * The following collection ObjectTypes are currently supported: Seq, Array
+ * The following collection ObjectTypes are currently supported: Seq, Array, ArrayData
  *
  * @param function A function that returns an expression, given an attribute that can be used
  *                 to access the current value.  This is does as a lambda function so that
@@ -265,14 +338,32 @@ case class MapObjects(
     inputData: Expression,
     elementType: DataType) extends Expression {
 
-  private val loopAttribute = AttributeReference("loopVar", elementType)()
-  private val completeFunction = function(loopAttribute)
+  private lazy val loopAttribute = AttributeReference("loopVar", elementType)()
+  private lazy val completeFunction = function(loopAttribute)
 
-  private val (lengthFunction, itemAccessor) = inputData.dataType match {
-    case ObjectType(cls) if cls.isAssignableFrom(classOf[Seq[_]]) =>
-      (".size()", (i: String) => s".apply($i)")
+  private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
+    case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+      (".size()", (i: String) => s".apply($i)", false)
     case ObjectType(cls) if cls.isArray =>
-      (".length", (i: String) => s"[$i]")
+      (".length", (i: String) => s"[$i]", false)
+    case ArrayType(s: StructType, _) =>
+      (".numElements()", (i: String) => s".getStruct($i, ${s.size})", false)
+    case ArrayType(a: ArrayType, _) =>
+      (".numElements()", (i: String) => s".getArray($i)", true)
+    case ArrayType(IntegerType, _) =>
+      (".numElements()", (i: String) => s".getInt($i)", true)
+    case ArrayType(LongType, _) =>
+      (".numElements()", (i: String) => s".getLong($i)", true)
+    case ArrayType(FloatType, _) =>
+      (".numElements()", (i: String) => s".getFloat($i)", true)
+    case ArrayType(DoubleType, _) =>
+      (".numElements()", (i: String) => s".getDouble($i)", true)
+    case ArrayType(ByteType, _) =>
+      (".numElements()", (i: String) => s".getByte($i)", true)
+    case ArrayType(ShortType, _) =>
+      (".numElements()", (i: String) => s".getShort($i)", true)
+    case ArrayType(BooleanType, _) =>
+      (".numElements()", (i: String) => s".getBoolean($i)", true)
   }
 
   override def nullable: Boolean = true
@@ -294,15 +385,38 @@ case class MapObjects(
     val loopIsNull = ctx.freshName("loopIsNull")
 
     val loopVariable = LambdaVariable(loopValue, loopIsNull, elementType)
-    val boundFunction = completeFunction transform {
+    val substitutedFunction = completeFunction transform {
       case a: AttributeReference if a == loopAttribute => loopVariable
     }
+    // A hack to run this through the analyzer (to bind extractions).
+    val boundFunction =
+      SimpleAnalyzer.execute(Project(Alias(substitutedFunction, "")() :: Nil, LocalRelation(Nil)))
+        .expressions.head.children.head
 
     val genFunction = boundFunction.gen(ctx)
     val dataLength = ctx.freshName("dataLength")
     val convertedArray = ctx.freshName("convertedArray")
     val loopIndex = ctx.freshName("loopIndex")
 
+    val convertedType = ctx.javaType(boundFunction.dataType)
+
+    // Because of the way Java defines nested arrays, we have to handle the syntax specially.
+    // Specifically, we have to insert the [$dataLength] in between the type and any extra nested
+    // array declarations (i.e. new String[1][]).
+    val arrayConstructor = if (convertedType contains "[]") {
+      val rawType = convertedType.takeWhile(_ != '[')
+      val arrayPart = convertedType.reverse.takeWhile(c => c == '[' || c == ']').reverse
+      s"new $rawType[$dataLength]$arrayPart"
+    } else {
+      s"new $convertedType[$dataLength]"
+    }
+
+    val loopNullCheck = if (primitiveElement) {
+      s"boolean $loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
+    } else {
+      s"boolean $loopIsNull = ${genInputData.isNull} || $loopValue == null;"
+    }
+
     s"""
       ${genInputData.code}
 
@@ -310,19 +424,19 @@ case class MapObjects(
       $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
 
       if (!${ev.isNull}) {
-        Object[] $convertedArray = null;
+        $convertedType[] $convertedArray = null;
         int $dataLength = ${genInputData.value}$lengthFunction;
-        $convertedArray = new Object[$dataLength];
+        $convertedArray = $arrayConstructor;
 
         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {
           $elementJavaType $loopValue =
             ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
-          boolean $loopIsNull = $loopValue == null;
+          $loopNullCheck
 
           ${genFunction.code}
 
-          $convertedArray[$loopIndex] = ${genFunction.value};
+          $convertedArray[$loopIndex] = ($convertedType)${genFunction.value};
           $loopIndex += 1;
         }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/328d1b3e/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
index 5206959..5f22e59 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
@@ -62,4 +62,8 @@ object ArrayBasedMapData {
     val values = map.valueArray.asInstanceOf[GenericArrayData].array
     keys.zip(values).toMap
   }
+
+  def toScalaMap(keys: Array[Any], values: Array[Any]): Map[Any, Any] =
+    // Pairs keys(i) with values(i); the shorter array bounds the result.
+    keys.iterator.zip(values.iterator).toMap
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/328d1b3e/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
index 642c56f..b4ea300 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
@@ -26,6 +26,8 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
 
   def copy(): ArrayData
 
+  def array: Array[Any]
+
   def toBooleanArray(): Array[Boolean] = {
     val size = numElements()
     val values = new Array[Boolean](size)
@@ -103,6 +105,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
     values
   }
 
+  def toObjectArray(elementType: DataType): Array[AnyRef] =
+    toArray[AnyRef](elementType) // reference array: primitive elements come back boxed
+
   def toArray[T: ClassTag](elementType: DataType): Array[T] = {
     val size = numElements()
     val values = new Array[T](size)

http://git-wip-us.apache.org/repos/asf/spark/blob/328d1b3e/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
index c381603..9448d88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.types
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
-class GenericArrayData(private[sql] val array: Array[Any]) extends ArrayData {
+class GenericArrayData(val array: Array[Any]) extends ArrayData {
 
   def this(seq: scala.collection.GenIterable[Any]) = this(seq.toArray)
 
@@ -29,6 +29,8 @@ class GenericArrayData(private[sql] val array: Array[Any]) extends ArrayData {
   def this(primitiveArray: Array[Long]) = this(primitiveArray.toSeq)
   def this(primitiveArray: Array[Float]) = this(primitiveArray.toSeq)
   def this(primitiveArray: Array[Double]) = this(primitiveArray.toSeq)
+  def this(primitiveArray: Array[Short]) = this(primitiveArray.toSeq)
+  def this(primitiveArray: Array[Byte]) = this(primitiveArray.toSeq)
   def this(primitiveArray: Array[Boolean]) = this(primitiveArray.toSeq)
 
   override def copy(): ArrayData = new GenericArrayData(array.clone())

http://git-wip-us.apache.org/repos/asf/spark/blob/328d1b3e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
index 99c993d..02e43dd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
@@ -17,158 +17,263 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
-import java.sql.{Date, Timestamp}
+import java.util
+
+import org.apache.spark.sql.types.{StructField, ArrayType, ArrayData}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.runtime.universe._
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.ScalaReflection._
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst._
 
-
 case class RepeatedStruct(s: Seq[PrimitiveData])
 
 case class NestedArray(a: Array[Array[Int]])
 
-class ProductEncoderSuite extends SparkFunSuite {
+case class BoxedData(
+    intField: java.lang.Integer,
+    longField: java.lang.Long,
+    doubleField: java.lang.Double,
+    floatField: java.lang.Float,
+    shortField: java.lang.Short,
+    byteField: java.lang.Byte,
+    booleanField: java.lang.Boolean)
 
-  test("convert PrimitiveData to InternalRow") {
-    val inputData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
-    val encoder = ProductEncoder[PrimitiveData]
-    val convertedData = encoder.toRow(inputData)
-
-    assert(convertedData.getInt(0) == 1)
-    assert(convertedData.getLong(1) == 1.toLong)
-    assert(convertedData.getDouble(2) == 1.toDouble)
-    assert(convertedData.getFloat(3) == 1.toFloat)
-    assert(convertedData.getShort(4) == 1.toShort)
-    assert(convertedData.getByte(5) == 1.toByte)
-    assert(convertedData.getBoolean(6) == true)
-  }
+case class RepeatedData(
+    arrayField: Seq[Int],
+    arrayFieldContainsNull: Seq[java.lang.Integer],
+    mapField: scala.collection.Map[Int, Long],
+    mapFieldNull: scala.collection.Map[Int, java.lang.Long],
+    structField: PrimitiveData)
 
-  test("convert Some[_] to InternalRow") {
-    val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
-    val inputData = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
-      Some(primitiveData))
-
-    val encoder = ProductEncoder[OptionalData]
-    val convertedData = encoder.toRow(inputData)
-
-    assert(convertedData.getInt(0) == 2)
-    assert(convertedData.getLong(1) == 2.toLong)
-    assert(convertedData.getDouble(2) == 2.toDouble)
-    assert(convertedData.getFloat(3) == 2.toFloat)
-    assert(convertedData.getShort(4) == 2.toShort)
-    assert(convertedData.getByte(5) == 2.toByte)
-    assert(convertedData.getBoolean(6) == true)
-
-    val nestedRow = convertedData.getStruct(7, 7)
-    assert(nestedRow.getInt(0) == 1)
-    assert(nestedRow.getLong(1) == 1.toLong)
-    assert(nestedRow.getDouble(2) == 1.toDouble)
-    assert(nestedRow.getFloat(3) == 1.toFloat)
-    assert(nestedRow.getShort(4) == 1.toShort)
-    assert(nestedRow.getByte(5) == 1.toByte)
-    assert(nestedRow.getBoolean(6) == true)
-  }
+case class SpecificCollection(l: List[Int])
 
-  test("convert None to InternalRow") {
-    val inputData = OptionalData(None, None, None, None, None, None, None, None)
-    val encoder = ProductEncoder[OptionalData]
-    val convertedData = encoder.toRow(inputData)
-
-    assert(convertedData.isNullAt(0))
-    assert(convertedData.isNullAt(1))
-    assert(convertedData.isNullAt(2))
-    assert(convertedData.isNullAt(3))
-    assert(convertedData.isNullAt(4))
-    assert(convertedData.isNullAt(5))
-    assert(convertedData.isNullAt(6))
-    assert(convertedData.isNullAt(7))
-  }
+class ProductEncoderSuite extends SparkFunSuite {
 
-  test("convert nullable but present data to InternalRow") {
-    val inputData = NullableData(
-      1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true, "test", new java.math.BigDecimal(1), new Date(0),
-      new Timestamp(0), Array[Byte](1, 2, 3))
-
-    val encoder = ProductEncoder[NullableData]
-    val convertedData = encoder.toRow(inputData)
-
-    assert(convertedData.getInt(0) == 1)
-    assert(convertedData.getLong(1) == 1.toLong)
-    assert(convertedData.getDouble(2) == 1.toDouble)
-    assert(convertedData.getFloat(3) == 1.toFloat)
-    assert(convertedData.getShort(4) == 1.toShort)
-    assert(convertedData.getByte(5) == 1.toByte)
-    assert(convertedData.getBoolean(6) == true)
-  }
+  encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
 
-  test("convert nullable data to InternalRow") {
-    val inputData =
-      NullableData(null, null, null, null, null, null, null, null, null, null, null, null)
-
-    val encoder = ProductEncoder[NullableData]
-    val convertedData = encoder.toRow(inputData)
-
-    assert(convertedData.isNullAt(0))
-    assert(convertedData.isNullAt(1))
-    assert(convertedData.isNullAt(2))
-    assert(convertedData.isNullAt(3))
-    assert(convertedData.isNullAt(4))
-    assert(convertedData.isNullAt(5))
-    assert(convertedData.isNullAt(6))
-    assert(convertedData.isNullAt(7))
-    assert(convertedData.isNullAt(8))
-    assert(convertedData.isNullAt(9))
-    assert(convertedData.isNullAt(10))
-    assert(convertedData.isNullAt(11))
-  }
+  // TODO: Support creating specific subclasses of Seq.
+  ignore("Specific collection types") { encodeDecodeTest(SpecificCollection(1 :: Nil)) }
 
-  test("convert repeated struct") {
-    val inputData = RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)
-    val encoder = ProductEncoder[RepeatedStruct]
-
-    val converted = encoder.toRow(inputData)
-    val convertedStruct = converted.getArray(0).getStruct(0, 7)
-    assert(convertedStruct.getInt(0) == 1)
-    assert(convertedStruct.getLong(1) == 1.toLong)
-    assert(convertedStruct.getDouble(2) == 1.toDouble)
-    assert(convertedStruct.getFloat(3) == 1.toFloat)
-    assert(convertedStruct.getShort(4) == 1.toShort)
-    assert(convertedStruct.getByte(5) == 1.toByte)
-    assert(convertedStruct.getBoolean(6) == true)
-  }
+  encodeDecodeTest(
+    OptionalData(
+      Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
+      Some(PrimitiveData(1, 1, 1, 1, 1, 1, true))))
 
-  test("convert nested seq") {
-    val convertedData = ProductEncoder[Tuple1[Seq[Seq[Int]]]].toRow(Tuple1(Seq(Seq(1))))
-    assert(convertedData.getArray(0).getArray(0).getInt(0) == 1)
+  encodeDecodeTest(OptionalData(None, None, None, None, None, None, None, None))
 
-    val convertedData2 = ProductEncoder[Tuple1[Seq[Seq[Seq[Int]]]]].toRow(Tuple1(Seq(Seq(Seq(1)))))
-    assert(convertedData2.getArray(0).getArray(0).getArray(0).getInt(0) == 1)
-  }
+  encodeDecodeTest(
+    BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
 
-  test("convert nested array") {
-    val convertedData = ProductEncoder[Tuple1[Array[Array[Int]]]].toRow(Tuple1(Array(Array(1))))
-  }
+  encodeDecodeTest(
+    BoxedData(null, null, null, null, null, null, null))
+
+  encodeDecodeTest(
+    RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil))
 
-  test("convert complex") {
-    val inputData = ComplexData(
+  encodeDecodeTest(
+    RepeatedData(
       Seq(1, 2),
-      Array(1, 2),
-      1 :: 2 :: Nil,
       Seq(new Integer(1), null, new Integer(2)),
       Map(1 -> 2L),
-      Map(1 -> new java.lang.Long(2)),
-      PrimitiveData(1, 1, 1, 1, 1, 1, true),
-      Array(Array(1)))
-
-    val encoder = ProductEncoder[ComplexData]
-    val convertedData = encoder.toRow(inputData)
-
-    assert(!convertedData.isNullAt(0))
-    val seq = convertedData.getArray(0)
-    assert(seq.numElements() == 2)
-    assert(seq.getInt(0) == 1)
-    assert(seq.getInt(1) == 2)
+      Map(1 -> null),
+      PrimitiveData(1, 1, 1, 1, 1, 1, true)))
+
+  encodeDecodeTest(("nullable Seq[Integer]", Seq[Integer](1, null)))
+
+  encodeDecodeTest(("Seq[(String, String)]",
+    Seq(("a", "b"))))
+  encodeDecodeTest(("Seq[(Int, Int)]",
+    Seq((1, 2))))
+  encodeDecodeTest(("Seq[(Long, Long)]",
+    Seq((1L, 2L))))
+  encodeDecodeTest(("Seq[(Float, Float)]",
+    Seq((1.toFloat, 2.toFloat))))
+  encodeDecodeTest(("Seq[(Double, Double)]",
+    Seq((1.toDouble, 2.toDouble))))
+  encodeDecodeTest(("Seq[(Short, Short)]",
+    Seq((1.toShort, 2.toShort))))
+  encodeDecodeTest(("Seq[(Byte, Byte)]",
+    Seq((1.toByte, 2.toByte))))
+  encodeDecodeTest(("Seq[(Boolean, Boolean)]",
+    Seq((true, false))))
+
+  // TODO: Decoding/encoding of complex maps.
+  ignore("complex maps") {
+    encodeDecodeTest(("Map[Int, (String, String)]",
+      Map(1 ->("a", "b"))))
+  }
+
+  encodeDecodeTest(("ArrayBuffer[(String, String)]",
+    ArrayBuffer(("a", "b"))))
+  encodeDecodeTest(("ArrayBuffer[(Int, Int)]",
+    ArrayBuffer((1, 2))))
+  encodeDecodeTest(("ArrayBuffer[(Long, Long)]",
+    ArrayBuffer((1L, 2L))))
+  encodeDecodeTest(("ArrayBuffer[(Float, Float)]",
+    ArrayBuffer((1.toFloat, 2.toFloat))))
+  encodeDecodeTest(("ArrayBuffer[(Double, Double)]",
+    ArrayBuffer((1.toDouble, 2.toDouble))))
+  encodeDecodeTest(("ArrayBuffer[(Short, Short)]",
+    ArrayBuffer((1.toShort, 2.toShort))))
+  encodeDecodeTest(("ArrayBuffer[(Byte, Byte)]",
+    ArrayBuffer((1.toByte, 2.toByte))))
+  encodeDecodeTest(("ArrayBuffer[(Boolean, Boolean)]",
+    ArrayBuffer((true, false))))
+
+  encodeDecodeTest(("Seq[Seq[(Int, Int)]]",
+    Seq(Seq((1, 2)))))
+
+  encodeDecodeTestCustom(("Array[Array[(Int, Int)]]",
+    Array(Array((1, 2)))))
+  { (l, r) => l._2(0)(0) == r._2(0)(0) }
+
+  encodeDecodeTestCustom(("Array[Array[Array[(Int, Int)]]]",
+    Array(Array(Array((1, 2))))))
+  { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) }
+
+  encodeDecodeTestCustom(("Array[Array[Array[Array[(Int, Int)]]]]",
+    Array(Array(Array(Array((1, 2)))))))
+  { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) }
+
+  encodeDecodeTestCustom(("Array[Array[Array[Array[Array[(Int, Int)]]]]]",
+    Array(Array(Array(Array(Array((1, 2))))))))
+  { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) }
+
+
+  encodeDecodeTestCustom(("Array[Array[Integer]]",
+    Array(Array[Integer](1))))
+  { (l, r) => l._2(0)(0) == r._2(0)(0) }
+
+  encodeDecodeTestCustom(("Array[Array[Int]]",
+    Array(Array(1))))
+  { (l, r) => l._2(0)(0) == r._2(0)(0) }
+
+  encodeDecodeTestCustom(("Array[Array[Array[Int]]]",
+    Array(Array(Array(1)))))
+  { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) }
+
+  encodeDecodeTestCustom(("Array[Array[Array[Array[Int]]]]",
+    Array(Array(Array(Array(1))))))
+  { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) }
+
+  encodeDecodeTestCustom(("Array[Array[Array[Array[Array[Int]]]]]",
+    Array(Array(Array(Array(Array(1)))))))
+  { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) }
+
+  encodeDecodeTest(("Array[Byte] null",
+    null: Array[Byte]))
+  encodeDecodeTestCustom(("Array[Byte]",
+    Array[Byte](1, 2, 3)))
+    { (l, r) => util.Arrays.equals(l._2, r._2) }
+
+  encodeDecodeTest(("Array[Int] null",
+    null: Array[Int]))
+  encodeDecodeTestCustom(("Array[Int]",
+    Array[Int](1, 2, 3)))
+    { (l, r) => util.Arrays.equals(l._2, r._2) }
+
+  encodeDecodeTest(("Array[Long] null",
+    null: Array[Long]))
+  encodeDecodeTestCustom(("Array[Long]",
+    Array[Long](1, 2, 3)))
+    { (l, r) => util.Arrays.equals(l._2, r._2) }
+
+  encodeDecodeTest(("Array[Double] null",
+    null: Array[Double]))
+  encodeDecodeTestCustom(("Array[Double]",
+    Array[Double](1, 2, 3)))
+    { (l, r) => util.Arrays.equals(l._2, r._2) }
+
+  encodeDecodeTest(("Array[Float] null",
+    null: Array[Float]))
+  encodeDecodeTestCustom(("Array[Float]",
+    Array[Float](1, 2, 3)))
+    { (l, r) => util.Arrays.equals(l._2, r._2) }
+
+  encodeDecodeTest(("Array[Boolean] null",
+    null: Array[Boolean]))
+  encodeDecodeTestCustom(("Array[Boolean]",
+    Array[Boolean](true, false)))
+    { (l, r) => util.Arrays.equals(l._2, r._2) }
+
+  encodeDecodeTest(("Array[Short] null",
+    null: Array[Short]))
+  encodeDecodeTestCustom(("Array[Short]",
+    Array[Short](1, 2, 3)))
+    { (l, r) => util.Arrays.equals(l._2, r._2) }
+
+  encodeDecodeTestCustom(("java.sql.Timestamp",
+    new java.sql.Timestamp(1)))
+    { (l, r) => l._2.toString == r._2.toString }
+
+  encodeDecodeTestCustom(("java.sql.Date", new java.sql.Date(1)))
+    { (l, r) => l._2.toString == r._2.toString }
+
+  /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */
+  protected def encodeDecodeTest[T <: Product : TypeTag](inputData: T): Unit =
+    encodeDecodeTestCustom[T](inputData)((l, r) => l == r)
+
+  /**
+   * Constructs a test that round-trips `inputData` through an encoder, checking with `c` that the
+   * decoded value matches the original.
+   */
+  protected def encodeDecodeTestCustom[T <: Product : TypeTag](
+      inputData: T)(
+      c: (T, T) => Boolean): Unit = {
+    test(s"encode/decode: $inputData") {
+      val encoder = try ProductEncoder[T] catch {
+        case e: Exception =>
+          fail("Exception thrown generating encoder", e)
+      }
+      val convertedData = encoder.toRow(inputData)
+      val schema = encoder.schema.toAttributes
+      val boundEncoder = encoder.bind(schema)
+      val convertedBack = try boundEncoder.fromRow(convertedData) catch {
+        case e: Exception =>
+          fail(
+           s"""Exception thrown while decoding
+              |Converted: $convertedData
+              |Schema: ${schema.mkString(",")}
+              |${encoder.schema.treeString}
+              |
+              |Construct Expressions:
+              |${boundEncoder.constructExpression.treeString}
+              |
+            """.stripMargin, e)
+      }
+
+      if (!c(inputData, convertedBack)) {
+        val types =
+          convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
+
+        val encodedData = convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
+          case (a: ArrayData, StructField(_, at: ArrayType, _, _)) =>
+            a.toArray[Any](at.elementType).toSeq
+          case (other, _) =>
+            other
+        }.mkString("[", ",", "]")
+
+        fail(
+          s"""Encoded/Decoded data does not match input data
+             |
+             |in:  $inputData
+             |out: $convertedBack
+             |types: $types
+             |
+             |Encoded Data: $encodedData
+             |Schema: ${schema.mkString(",")}
+             |${encoder.schema.treeString}
+             |
+             |Extract Expressions:
+             |${boundEncoder.extractExpressions.map(_.treeString).mkString("\n")}
+             |
+             |Construct Expressions:
+             |${boundEncoder.constructExpression.treeString}
+             |
+           """.stripMargin)
+      }
+    }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org