You are viewing a plain-text version of this message. The canonical version is in the commits@spark.apache.org mailing-list archive (its hyperlink was lost in the plain-text conversion).
Posted to commits@spark.apache.org by ma...@apache.org on 2015/10/21 20:06:37 UTC

spark git commit: [SPARK-11216] [SQL] add encoder/decoder for external row

Repository: spark
Updated Branches:
  refs/heads/master f62e32608 -> ccf536f90


[SPARK-11216] [SQL] add encoder/decoder for external row

Implement encoding and decoding of external rows (`org.apache.spark.sql.Row`) on top of `ClassEncoder`.

TODO:
* code cleanup
* ~~fix corner cases~~
* refactor the encoder interface
* improve the product codegen tests to cover more corner cases.

Author: Wenchen Fan <we...@databricks.com>

Closes #9184 from cloud-fan/encoder.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ccf536f9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ccf536f9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ccf536f9

Branch: refs/heads/master
Commit: ccf536f903ef1f81fb3e1b6ce781d5e40d0ae3e0
Parents: f62e326
Author: Wenchen Fan <we...@databricks.com>
Authored: Wed Oct 21 11:06:34 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Wed Oct 21 11:06:34 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    |   6 +-
 .../sql/catalyst/encoders/ClassEncoder.scala    |  75 ++++++
 .../spark/sql/catalyst/encoders/Encoder.scala   |   2 +-
 .../sql/catalyst/encoders/ProductEncoder.scala  |  46 +---
 .../sql/catalyst/encoders/RowEncoder.scala      | 234 +++++++++++++++++++
 .../sql/catalyst/expressions/objects.scala      |  46 +++-
 .../spark/sql/types/ArrayBasedMapData.scala     |   4 +
 .../apache/spark/sql/RandomDataGenerator.scala  |   4 +-
 .../sql/catalyst/encoders/RowEncoderSuite.scala |  96 ++++++++
 9 files changed, 459 insertions(+), 54 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ccf536f9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8edd649..27c96f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.util.Utils
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
 
 /**
  * A default version of ScalaReflection that uses the runtime universe.
@@ -142,7 +142,7 @@ trait ScalaReflection {
   }
 
   /**
-   * Returns an expression that can be used to construct an object of type `T` given a an input
+   * Returns an expression that can be used to construct an object of type `T` given an input
    * row with a compatible schema.  Fields of the row will be extracted using UnresolvedAttributes
    * of the same name as the constructor arguments.  Nested classes will have their fields accessed
    * using UnresolvedExtractValue.

http://git-wip-us.apache.org/repos/asf/spark/blob/ccf536f9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
new file mode 100644
index 0000000..f3a1063
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.types.{ObjectType, StructType}
+
+/**
+ * A generic encoder for JVM objects.
+ *
+ * @param schema The schema after converting `T` to a Spark SQL row.
+ * @param extractExpressions A set of expressions, one for each top-level field that can be used to
+ *                           extract the values from a raw object.
+ * @param clsTag A classtag for `T`.
+ */
+case class ClassEncoder[T](
+    schema: StructType,
+    extractExpressions: Seq[Expression],
+    constructExpression: Expression,
+    clsTag: ClassTag[T])
+  extends Encoder[T] {
+
+  private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+  private val inputRow = new GenericMutableRow(1)
+
+  private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+  private val dataType = ObjectType(clsTag.runtimeClass)
+
+  override def toRow(t: T): InternalRow = {
+    if (t == null) {
+      null
+    } else {
+      inputRow(0) = t
+      extractProjection(inputRow)
+    }
+  }
+
+  override def fromRow(row: InternalRow): T = {
+    if (row eq null) {
+      null.asInstanceOf[T]
+    } else {
+      constructProjection(row).get(0, dataType).asInstanceOf[T]
+    }
+  }
+
+  override def bind(schema: Seq[Attribute]): ClassEncoder[T] = {
+    val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema))
+    val analyzedPlan = SimpleAnalyzer.execute(plan)
+    val resolvedExpression = analyzedPlan.expressions.head.children.head
+    val boundExpression = BindReferences.bindReference(resolvedExpression, schema)
+
+    copy(constructExpression = boundExpression)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/ccf536f9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index 3618247..bdb1c09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -46,7 +46,7 @@ trait Encoder[T] {
 
   /**
    * Returns an object of type `T`, extracting the required values from the provided row.  Note that
-   * you must bind` and encoder to a specific schema before you can call this function.
+   * you must bind the encoder to a specific schema before you can call this function.
    */
   def fromRow(row: InternalRow): T
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ccf536f9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
index b038188..4f7ce45 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -17,15 +17,11 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
-import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
-import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.types.{ObjectType, StructType}
 
 /**
@@ -44,44 +40,6 @@ object ProductEncoder {
     val constructExpression = ScalaReflection.constructorFor[T]
     new ClassEncoder[T](schema, extractExpressions, constructExpression, ClassTag[T](cls))
   }
-}
-
-/**
- * A generic encoder for JVM objects.
- *
- * @param schema The schema after converting `T` to a Spark SQL row.
- * @param extractExpressions A set of expressions, one for each top-level field that can be used to
- *                           extract the values from a raw object.
- * @param clsTag A classtag for `T`.
- */
-case class ClassEncoder[T](
-    schema: StructType,
-    extractExpressions: Seq[Expression],
-    constructExpression: Expression,
-    clsTag: ClassTag[T])
-  extends Encoder[T] {
 
-  private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
-  private val inputRow = new GenericMutableRow(1)
 
-  private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
-  private val dataType = ObjectType(clsTag.runtimeClass)
-
-  override def toRow(t: T): InternalRow = {
-    inputRow(0) = t
-    extractProjection(inputRow)
-  }
-
-  override def fromRow(row: InternalRow): T = {
-    constructProjection(row).get(0, dataType).asInstanceOf[T]
-  }
-
-  override def bind(schema: Seq[Attribute]): ClassEncoder[T] = {
-    val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema))
-    val analyzedPlan = SimpleAnalyzer.execute(plan)
-    val resolvedExpression = analyzedPlan.expressions.head.children.head
-    val boundExpression = BindReferences.bindReference(resolvedExpression, schema)
-
-    copy(constructExpression = boundExpression)
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ccf536f9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
new file mode 100644
index 0000000..3e74aab
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.collection.Map
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+object RowEncoder {
+
+  def apply(schema: StructType): ClassEncoder[Row] = {
+    val cls = classOf[Row]
+    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+    val extractExpressions = extractorsFor(inputObject, schema)
+    val constructExpression = constructorFor(schema)
+    new ClassEncoder[Row](
+      schema,
+      extractExpressions.asInstanceOf[CreateStruct].children,
+      constructExpression,
+      ClassTag(cls))
+  }
+
+  private def extractorsFor(
+      inputObject: Expression,
+      inputType: DataType): Expression = inputType match {
+    case BooleanType | ByteType | ShortType | IntegerType | LongType |
+         FloatType | DoubleType | BinaryType => inputObject
+
+    case TimestampType =>
+      StaticInvoke(
+        DateTimeUtils,
+        TimestampType,
+        "fromJavaTimestamp",
+        inputObject :: Nil)
+
+    case DateType =>
+      StaticInvoke(
+        DateTimeUtils,
+        DateType,
+        "fromJavaDate",
+        inputObject :: Nil)
+
+    case _: DecimalType =>
+      StaticInvoke(
+        Decimal,
+        DecimalType.SYSTEM_DEFAULT,
+        "apply",
+        inputObject :: Nil)
+
+    case StringType =>
+      StaticInvoke(
+        classOf[UTF8String],
+        StringType,
+        "fromString",
+        inputObject :: Nil)
+
+    case t @ ArrayType(et, _) => et match {
+      case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+        NewInstance(
+          classOf[GenericArrayData],
+          inputObject :: Nil,
+          dataType = t)
+      case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeFor(et))
+    }
+
+    case t @ MapType(kt, vt, valueNullable) =>
+      val keys =
+        Invoke(
+          Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
+          "toSeq",
+          ObjectType(classOf[scala.collection.Seq[_]]))
+      val convertedKeys = extractorsFor(keys, ArrayType(kt, false))
+
+      val values =
+        Invoke(
+          Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
+          "toSeq",
+          ObjectType(classOf[scala.collection.Seq[_]]))
+      val convertedValues = extractorsFor(values, ArrayType(vt, valueNullable))
+
+      NewInstance(
+        classOf[ArrayBasedMapData],
+        convertedKeys :: convertedValues :: Nil,
+        dataType = t)
+
+    case StructType(fields) =>
+      val convertedFields = fields.zipWithIndex.map { case (f, i) =>
+        If(
+          Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
+          Literal.create(null, f.dataType),
+          extractorsFor(
+            Invoke(inputObject, "get", externalDataTypeFor(f.dataType), Literal(i) :: Nil),
+            f.dataType))
+      }
+      CreateStruct(convertedFields)
+  }
+
+  private def externalDataTypeFor(dt: DataType): DataType = dt match {
+    case BooleanType | ByteType | ShortType | IntegerType | LongType |
+         FloatType | DoubleType | BinaryType => dt
+    case TimestampType => ObjectType(classOf[java.sql.Timestamp])
+    case DateType => ObjectType(classOf[java.sql.Date])
+    case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
+    case StringType => ObjectType(classOf[java.lang.String])
+    case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
+    case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
+    case _: StructType => ObjectType(classOf[Row])
+  }
+
+  /** Builds a Row from the top-level fields; a null input column becomes a null field. */
+  private def constructorFor(schema: StructType): Expression = {
+    val fields = schema.zipWithIndex.map { case (f, i) =>
+      val field = BoundReference(i, f.dataType, f.nullable)
+      If(
+        IsNull(field),
+        Literal.create(null, externalDataTypeFor(f.dataType)),
+        constructorFor(field, f.dataType))
+    }
+    CreateRow(fields)
+  }
+
+  private def constructorFor(input: Expression, dataType: DataType): Expression = dataType match {
+    case BooleanType | ByteType | ShortType | IntegerType | LongType |
+         FloatType | DoubleType | BinaryType => input
+
+    case TimestampType =>
+      StaticInvoke(
+        DateTimeUtils,
+        ObjectType(classOf[java.sql.Timestamp]),
+        "toJavaTimestamp",
+        input :: Nil)
+
+    case DateType =>
+      StaticInvoke(
+        DateTimeUtils,
+        ObjectType(classOf[java.sql.Date]),
+        "toJavaDate",
+        input :: Nil)
+
+    case _: DecimalType =>
+      Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+
+    case StringType =>
+      Invoke(input, "toString", ObjectType(classOf[String]))
+
+    case ArrayType(et, nullable) =>
+      val arrayData =
+        Invoke(
+          MapObjects(constructorFor(_, et), input, et),
+          "array",
+          ObjectType(classOf[Array[_]]))
+      StaticInvoke(
+        scala.collection.mutable.WrappedArray,
+        ObjectType(classOf[Seq[_]]),
+        "make",
+        arrayData :: Nil)
+
+    case MapType(kt, vt, valueNullable) =>
+      val keyArrayType = ArrayType(kt, false)
+      val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType), keyArrayType)
+
+      val valueArrayType = ArrayType(vt, valueNullable)
+      val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType), valueArrayType)
+
+      StaticInvoke(
+        ArrayBasedMapData,
+        ObjectType(classOf[Map[_, _]]),
+        "toScalaMap",
+        keyData :: valueData :: Nil)
+
+    case StructType(fields) =>
+      val convertedFields = fields.zipWithIndex.map { case (f, i) =>
+        If(
+          Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
+          Literal.create(null, externalDataTypeFor(f.dataType)),
+          constructorFor(getField(input, i, f.dataType), f.dataType))
+      }
+      CreateRow(convertedFields)
+  }
+
+  private def getField(
+     row: Expression,
+     ordinal: Int,
+     dataType: DataType): Expression = dataType match {
+    case BooleanType =>
+      Invoke(row, "getBoolean", dataType, Literal(ordinal) :: Nil)
+    case ByteType =>
+      Invoke(row, "getByte", dataType, Literal(ordinal) :: Nil)
+    case ShortType =>
+      Invoke(row, "getShort", dataType, Literal(ordinal) :: Nil)
+    case IntegerType | DateType =>
+      Invoke(row, "getInt", dataType, Literal(ordinal) :: Nil)
+    case LongType | TimestampType =>
+      Invoke(row, "getLong", dataType, Literal(ordinal) :: Nil)
+    case FloatType =>
+      Invoke(row, "getFloat", dataType, Literal(ordinal) :: Nil)
+    case DoubleType =>
+      Invoke(row, "getDouble", dataType, Literal(ordinal) :: Nil)
+    case t: DecimalType =>
+      Invoke(row, "getDecimal", dataType, Seq(ordinal, t.precision, t.scale).map(Literal(_)))
+    case StringType =>
+      Invoke(row, "getUTF8String", dataType, Literal(ordinal) :: Nil)
+    case BinaryType =>
+      Invoke(row, "getBinary", dataType, Literal(ordinal) :: Nil)
+    case CalendarIntervalType =>
+      Invoke(row, "getInterval", dataType, Literal(ordinal) :: Nil)
+    case t: StructType =>
+      Invoke(row, "getStruct", dataType, Literal(ordinal) :: Literal(t.size) :: Nil)
+    case _: ArrayType =>
+      Invoke(row, "getArray", dataType, Literal(ordinal) :: Nil)
+    case _: MapType =>
+      Invoke(row, "getMap", dataType, Literal(ordinal) :: Nil)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/ccf536f9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index e8c1c93..8fc00ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
 
 import scala.language.existentials
 
-import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.types._
 
@@ -364,6 +365,10 @@ case class MapObjects(
       (".numElements()", (i: String) => s".getShort($i)", true)
     case ArrayType(BooleanType, _) =>
       (".numElements()", (i: String) => s".getBoolean($i)", true)
+    case ArrayType(StringType, _) =>
+      (".numElements()", (i: String) => s".getUTF8String($i)", false)
+    case ArrayType(_: MapType, _) =>
+      (".numElements()", (i: String) => s".getMap($i)", false)
   }
 
   override def nullable: Boolean = true
@@ -398,7 +403,7 @@ case class MapObjects(
     val convertedArray = ctx.freshName("convertedArray")
     val loopIndex = ctx.freshName("loopIndex")
 
-    val convertedType = ctx.javaType(boundFunction.dataType)
+    val convertedType = ctx.boxedType(boundFunction.dataType)
 
     // Because of the way Java defines nested arrays, we have to handle the syntax specially.
     // Specifically, we have to insert the [$dataLength] in between the type and any extra nested
@@ -434,9 +439,13 @@ case class MapObjects(
             ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
           $loopNullCheck
 
-          ${genFunction.code}
+          if ($loopIsNull) {
+            $convertedArray[$loopIndex] = null;
+          } else {
+            ${genFunction.code}
+            $convertedArray[$loopIndex] = ${genFunction.value};
+          }
 
-          $convertedArray[$loopIndex] = ($convertedType)${genFunction.value};
           $loopIndex += 1;
         }
 
@@ -446,3 +455,32 @@ case class MapObjects(
     """
   }
 }
+
+case class CreateRow(children: Seq[Expression]) extends Expression {
+  override def dataType: DataType = ObjectType(classOf[Row])
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val rowClass = classOf[GenericRow].getName
+    val values = ctx.freshName("values")
+    s"""
+      boolean ${ev.isNull} = false;
+      final Object[] $values = new Object[${children.size}];
+    """ +
+      children.zipWithIndex.map { case (e, i) =>
+        val eval = e.gen(ctx)
+        eval.code + s"""
+          if (${eval.isNull}) {
+            $values[$i] = null;
+          } else {
+            $values[$i] = ${eval.value};
+          }
+         """
+      }.mkString("\n") +
+      s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);"
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/ccf536f9/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
index 5f22e59..e5ffe32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
@@ -66,4 +66,8 @@ object ArrayBasedMapData {
   def toScalaMap(keys: Array[Any], values: Array[Any]): Map[Any, Any] = {
     keys.zip(values).toMap
   }
+
+  def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
+    keys.zip(values).toMap
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ccf536f9/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index e483950..7614f05 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -148,7 +148,7 @@ object RandomDataGenerator {
         () => BigDecimal.apply(
           rand.nextLong() % math.pow(10, precision).toLong,
           scale,
-          new MathContext(precision)))
+          new MathContext(precision)).bigDecimal)
       case DoubleType => randomNumeric[Double](
         rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue,
           Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0))
@@ -166,7 +166,7 @@ object RandomDataGenerator {
       case NullType => Some(() => null)
       case ArrayType(elementType, containsNull) => {
         forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map {
-          elementGenerator => () => Array.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator())
+          elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator())
         }
       }
       case MapType(keyType, valueType, valueContainsNull) => {

http://git-wip-us.apache.org/repos/asf/spark/blob/ccf536f9/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
new file mode 100644
index 0000000..6041b62
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{RandomDataGenerator, Row}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+class RowEncoderSuite extends SparkFunSuite {
+
+  private val structOfString = new StructType().add("str", StringType)
+  private val arrayOfString = ArrayType(StringType)
+  private val mapOfString = MapType(StringType, StringType)
+
+  encodeDecodeTest(
+    new StructType()
+      .add("boolean", BooleanType)
+      .add("byte", ByteType)
+      .add("short", ShortType)
+      .add("int", IntegerType)
+      .add("long", LongType)
+      .add("float", FloatType)
+      .add("double", DoubleType)
+      .add("decimal", DecimalType.SYSTEM_DEFAULT)
+      .add("string", StringType)
+      .add("binary", BinaryType)
+      .add("date", DateType)
+      .add("timestamp", TimestampType))
+
+  encodeDecodeTest(
+    new StructType()
+      .add("arrayOfString", arrayOfString)
+      .add("arrayOfArrayOfString", ArrayType(arrayOfString))
+      .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType)))
+      .add("arrayOfMap", ArrayType(mapOfString))
+      .add("arrayOfStruct", ArrayType(structOfString)))
+
+  encodeDecodeTest(
+    new StructType()
+      .add("mapOfIntAndString", MapType(IntegerType, StringType))
+      .add("mapOfStringAndArray", MapType(StringType, arrayOfString))
+      .add("mapOfArrayAndInt", MapType(arrayOfString, IntegerType))
+      .add("mapOfArray", MapType(arrayOfString, arrayOfString))
+      .add("mapOfStringAndStruct", MapType(StringType, structOfString))
+      .add("mapOfStructAndString", MapType(structOfString, StringType))
+      .add("mapOfStruct", MapType(structOfString, structOfString)))
+
+  encodeDecodeTest(
+    new StructType()
+      .add("structOfString", structOfString)
+      .add("structOfStructOfString", new StructType().add("struct", structOfString))
+      .add("structOfArray", new StructType().add("array", arrayOfString))
+      .add("structOfMap", new StructType().add("map", mapOfString))
+      .add("structOfArrayAndMap",
+        new StructType().add("array", arrayOfString).add("map", mapOfString)))
+
+  private def encodeDecodeTest(schema: StructType): Unit = {
+    test(s"encode/decode: ${schema.simpleString}") {
+      val encoder = RowEncoder(schema)
+      val inputGenerator = RandomDataGenerator.forType(schema).get
+
+      var input: Row = null
+      try {
+        for (_ <- 1 to 5) {
+          input = inputGenerator.apply().asInstanceOf[Row]
+          val row = encoder.toRow(input)
+          val convertedBack = encoder.fromRow(row)
+          assert(input == convertedBack)
+        }
+      } catch {
+        case e: Exception =>
+          fail(
+            s"""
+               |schema: ${schema.simpleString}
+               |input: ${input}
+             """.stripMargin, e)
+      }
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org