Posted to commits@spark.apache.org by yh...@apache.org on 2015/10/08 23:28:18 UTC

spark git commit: [SPARK-10993] [SQL] Initial code generated encoder for product types

Repository: spark
Updated Branches:
  refs/heads/master a8226a9f1 -> 9e66a53c9


[SPARK-10993] [SQL] Initial code generated encoder for product types

This PR is a first cut at code generating an encoder that takes a Scala `Product` type and converts it directly into the Tungsten binary format.  This is done through the addition of a new set of expressions that can be used to invoke methods on raw JVM objects, extracting fields and converting the result into the required format.  These can then be used directly in an `UnsafeProjection`, allowing us to leverage the existing encoding logic.

According to some simple benchmarks, this can significantly speed up conversion (~4x).  However, replacing CatalystConverters is deferred to a later PR to keep this one at a reasonable size.

```scala
case class SomeInts(a: Int, b: Int, c: Int, d: Int, e: Int)

val data = SomeInts(1, 2, 3, 4, 5)
val encoder = ProductEncoder[SomeInts]
val converter = CatalystTypeConverters.createToCatalystConverter(ScalaReflection.schemaFor[SomeInts].dataType)

(1 to 5).foreach { iter =>
  benchmark(s"converter $iter") {
    var i = 100000000
    while (i > 0) {
      val res = converter(data).asInstanceOf[InternalRow]
      assert(res.getInt(0) == 1)
      assert(res.getInt(1) == 2)
      i -= 1
    }
  }

  benchmark(s"encoder $iter") {
    var i = 100000000
    while (i > 0) {
      val res = encoder.toRow(data)
      assert(res.getInt(0) == 1)
      assert(res.getInt(1) == 2)
      i -= 1
    }
  }
}
```
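
The `benchmark` helper above is not part of the patch; a minimal timing helper along these lines (an assumption, not the author's exact code) is enough to reproduce the shape of the output below:

```scala
// Hypothetical helper assumed by the snippet above: time a block and print the result.
def benchmark(name: String)(f: => Unit): Unit = {
  val start = System.nanoTime()
  f  // run the measured block
  println(s"[info] $name: ${(System.nanoTime() - start) / 1000000}ms")
}
```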

Results:
```
[info] converter 1: 7170ms
[info] encoder 1: 1888ms
[info] converter 2: 6763ms
[info] encoder 2: 1824ms
[info] converter 3: 6912ms
[info] encoder 3: 1802ms
[info] converter 4: 7131ms
[info] encoder 4: 1798ms
[info] converter 5: 7350ms
[info] encoder 5: 1912ms
```
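
Averaged over the five runs, that is ~7065ms for the converter versus ~1845ms for the encoder, a ~3.8x speedup, consistent with the ~4x figure above.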

Author: Michael Armbrust <mi...@databricks.com>

Closes #9019 from marmbrus/productEncoder.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9e66a53c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9e66a53c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9e66a53c

Branch: refs/heads/master
Commit: 9e66a53c9955285a85c19f55c3ef62db2e1b868a
Parents: a8226a9
Author: Michael Armbrust <mi...@databricks.com>
Authored: Thu Oct 8 14:28:14 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Thu Oct 8 14:28:14 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 238 ++++++++++++-
 .../spark/sql/catalyst/encoders/Encoder.scala   |  44 +++
 .../sql/catalyst/encoders/ProductEncoder.scala  |  67 ++++
 .../expressions/codegen/CodeGenerator.scala     |   4 +-
 .../sql/catalyst/expressions/objects.scala      | 334 +++++++++++++++++++
 .../spark/sql/types/GenericArrayData.scala      |   9 +
 .../org/apache/spark/sql/types/ObjectType.scala |  42 +++
 .../catalyst/encoders/ProductEncoderSuite.scala | 174 ++++++++++
 8 files changed, 910 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9e66a53c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 2442341..8b733f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst
 
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 import org.apache.spark.sql.catalyst.expressions._
@@ -75,6 +76,242 @@ trait ScalaReflection {
    */
   private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe
 
+  /**
+   * Returns the Spark SQL DataType for a given Scala type.  Where there is no exact mapping
+   * to a native type, an ObjectType is returned. Special handling is also used for Arrays,
+   * including those that hold primitive types.
+   */
+  def dataTypeFor(tpe: `Type`): DataType = tpe match {
+    case t if t <:< definitions.IntTpe => IntegerType
+    case t if t <:< definitions.LongTpe => LongType
+    case t if t <:< definitions.DoubleTpe => DoubleType
+    case t if t <:< definitions.FloatTpe => FloatType
+    case t if t <:< definitions.ShortTpe => ShortType
+    case t if t <:< definitions.ByteTpe => ByteType
+    case t if t <:< definitions.BooleanTpe => BooleanType
+    case t if t <:< localTypeOf[Array[Byte]] => BinaryType
+    case _ =>
+      val className: String = tpe.erasure.typeSymbol.asClass.fullName
+      className match {
+        case "scala.Array" =>
+          val TypeRef(_, _, Seq(arrayType)) = tpe
+          val cls = arrayType match {
+            case t if t <:< definitions.IntTpe => classOf[Array[Int]]
+            case t if t <:< definitions.LongTpe => classOf[Array[Long]]
+            case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
+            case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
+            case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
+            case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
+            case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
+            case other =>
+              // There is probably a better way to do this, but I couldn't find it...
+              val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
+              java.lang.reflect.Array.newInstance(elementType, 1).getClass
+
+          }
+          ObjectType(cls)
+        case other => ObjectType(Utils.classForName(className))
+      }
+  }
+
+  /** Returns expressions for extracting all the fields from the given type. */
+  def extractorsFor[T : TypeTag](inputObject: Expression): Seq[Expression] = {
+    ScalaReflectionLock.synchronized {
+      extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateStruct].children
+    }
+  }
+
+  /** Helper for extracting internal fields from a case class. */
+  protected def extractorFor(
+      inputObject: Expression,
+      tpe: `Type`): Expression = ScalaReflectionLock.synchronized {
+    if (!inputObject.dataType.isInstanceOf[ObjectType]) {
+      inputObject
+    } else {
+      tpe match {
+        case t if t <:< localTypeOf[Option[_]] =>
+          val TypeRef(_, _, Seq(optType)) = t
+          optType match {
+            // For primitive types we must manually unbox the value of the object.
+            case t if t <:< definitions.IntTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject),
+                "intValue",
+                IntegerType)
+            case t if t <:< definitions.LongTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject),
+                "longValue",
+                LongType)
+            case t if t <:< definitions.DoubleTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject),
+                "doubleValue",
+                DoubleType)
+            case t if t <:< definitions.FloatTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject),
+                "floatValue",
+                FloatType)
+            case t if t <:< definitions.ShortTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject),
+                "shortValue",
+                ShortType)
+            case t if t <:< definitions.ByteTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject),
+                "byteValue",
+                ByteType)
+            case t if t <:< definitions.BooleanTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject),
+                "booleanValue",
+                BooleanType)
+
+            // For non-primitives, we can just extract the object from the Option and then recurse.
+            case other =>
+              val className: String = optType.erasure.typeSymbol.asClass.fullName
+              val classObj = Utils.classForName(className)
+              val optionObjectType = ObjectType(classObj)
+
+              val unwrapped = UnwrapOption(optionObjectType, inputObject)
+              expressions.If(
+                IsNull(unwrapped),
+                expressions.Literal.create(null, schemaFor(optType).dataType),
+                extractorFor(unwrapped, optType))
+          }
+
+        case t if t <:< localTypeOf[Product] =>
+          val formalTypeArgs = t.typeSymbol.asClass.typeParams
+          val TypeRef(_, _, actualTypeArgs) = t
+          val constructorSymbol = t.member(nme.CONSTRUCTOR)
+          val params = if (constructorSymbol.isMethod) {
+            constructorSymbol.asMethod.paramss
+          } else {
+            // Find the primary constructor, and use its parameter ordering.
+            val primaryConstructorSymbol: Option[Symbol] =
+              constructorSymbol.asTerm.alternatives.find(s =>
+                s.isMethod && s.asMethod.isPrimaryConstructor)
+
+            if (primaryConstructorSymbol.isEmpty) {
+              sys.error("Internal SQL error: Product object did not have a primary constructor.")
+            } else {
+              primaryConstructorSymbol.get.asMethod.paramss
+            }
+          }
+
+          CreateStruct(params.head.map { p =>
+            val fieldName = p.name.toString
+            val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+            val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+            extractorFor(fieldValue, fieldType)
+          })
+
+        case t if t <:< localTypeOf[Array[_]] =>
+          val TypeRef(_, _, Seq(elementType)) = t
+          val elementDataType = dataTypeFor(elementType)
+          val Schema(dataType, nullable) = schemaFor(elementType)
+
+          if (!elementDataType.isInstanceOf[AtomicType]) {
+            MapObjects(extractorFor(_, elementType), inputObject, elementDataType)
+          } else {
+            NewInstance(
+              classOf[GenericArrayData],
+              inputObject :: Nil,
+              dataType = ArrayType(dataType, nullable))
+          }
+
+        case t if t <:< localTypeOf[Seq[_]] =>
+          val TypeRef(_, _, Seq(elementType)) = t
+          val elementDataType = dataTypeFor(elementType)
+          val Schema(dataType, nullable) = schemaFor(elementType)
+
+          if (!elementDataType.isInstanceOf[AtomicType]) {
+            MapObjects(extractorFor(_, elementType), inputObject, elementDataType)
+          } else {
+            NewInstance(
+              classOf[GenericArrayData],
+              inputObject :: Nil,
+              dataType = ArrayType(dataType, nullable))
+          }
+
+        case t if t <:< localTypeOf[Map[_, _]] =>
+          val TypeRef(_, _, Seq(keyType, valueType)) = t
+          val Schema(keyDataType, _) = schemaFor(keyType)
+          val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+
+          val rawMap = inputObject
+          val keys =
+            NewInstance(
+              classOf[GenericArrayData],
+              Invoke(rawMap, "keys", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil,
+              dataType = ObjectType(classOf[ArrayData]))
+          val values =
+            NewInstance(
+              classOf[GenericArrayData],
+              Invoke(rawMap, "values", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil,
+              dataType = ObjectType(classOf[ArrayData]))
+          NewInstance(
+            classOf[ArrayBasedMapData],
+            keys :: values :: Nil,
+            dataType = MapType(keyDataType, valueDataType, valueNullable))
+
+        case t if t <:< localTypeOf[String] =>
+          StaticInvoke(
+            classOf[UTF8String],
+            StringType,
+            "fromString",
+            inputObject :: Nil)
+
+        case t if t <:< localTypeOf[java.sql.Timestamp] =>
+          StaticInvoke(
+            DateTimeUtils,
+            TimestampType,
+            "fromJavaTimestamp",
+            inputObject :: Nil)
+
+        case t if t <:< localTypeOf[java.sql.Date] =>
+          StaticInvoke(
+            DateTimeUtils,
+            DateType,
+            "fromJavaDate",
+            inputObject :: Nil)
+        case t if t <:< localTypeOf[BigDecimal] =>
+          StaticInvoke(
+            Decimal,
+            DecimalType.SYSTEM_DEFAULT,
+            "apply",
+            inputObject :: Nil)
+
+        case t if t <:< localTypeOf[java.math.BigDecimal] =>
+          StaticInvoke(
+            Decimal,
+            DecimalType.SYSTEM_DEFAULT,
+            "apply",
+            inputObject :: Nil)
+
+        case t if t <:< localTypeOf[java.lang.Integer] =>
+          Invoke(inputObject, "intValue", IntegerType)
+        case t if t <:< localTypeOf[java.lang.Long] =>
+          Invoke(inputObject, "longValue", LongType)
+        case t if t <:< localTypeOf[java.lang.Double] =>
+          Invoke(inputObject, "doubleValue", DoubleType)
+        case t if t <:< localTypeOf[java.lang.Float] =>
+          Invoke(inputObject, "floatValue", FloatType)
+        case t if t <:< localTypeOf[java.lang.Short] =>
+          Invoke(inputObject, "shortValue", ShortType)
+        case t if t <:< localTypeOf[java.lang.Byte] =>
+          Invoke(inputObject, "byteValue", ByteType)
+        case t if t <:< localTypeOf[java.lang.Boolean] =>
+          Invoke(inputObject, "booleanValue", BooleanType)
+
+        case other =>
+          throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
+      }
+    }
+  }
+
   /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
   def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
     val className: String = tpe.erasure.typeSymbol.asClass.fullName
@@ -91,7 +328,6 @@ trait ScalaReflection {
       case t if t <:< localTypeOf[Option[_]] =>
         val TypeRef(_, _, Seq(optType)) = t
         Schema(schemaFor(optType).dataType, nullable = true)
-      // Need to decide if we actually need a special type here.
       case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
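
To make the new extraction path concrete, here is a hedged, hand-built sketch of the expression tree `extractorFor` derives for a String field (assuming this patch is on the classpath and, since ObjectType is private[sql], compiled from within the sql package): the field accessor is invoked on the raw object and the result is converted with a static `UTF8String.fromString` call.

```scala
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Invoke, StaticInvoke}
import org.apache.spark.sql.types.{ObjectType, StringType}
import org.apache.spark.unsafe.types.UTF8String

case class Person(name: String)  // hypothetical example type

// Slot 0 of the input row holds the raw Person object.
val inputObject = BoundReference(0, ObjectType(classOf[Person]), nullable = true)
// Invoke the `name` accessor on the raw object...
val nameField = Invoke(inputObject, "name", ObjectType(classOf[String]))
// ...then convert the returned String to the internal representation.
val extractor = StaticInvoke(classOf[UTF8String], StringType, "fromString", nameField :: Nil)
```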

http://git-wip-us.apache.org/repos/asf/spark/blob/9e66a53c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
new file mode 100644
index 0000000..8dacfa9
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
+ *
+ * Encoders are not intended to be thread-safe and thus they are allowed to avoid internal locking
+ * and reuse internal buffers to improve performance.
+ */
+trait Encoder[T] {
+  /** Returns the schema for encoding objects of this type as a Row. */
+  def schema: StructType
+
+  /** A ClassTag that can be used to construct an Array to contain a collection of `T`. */
+  def clsTag: ClassTag[T]
+
+  /**
+   * Returns an encoded version of `t` as a Spark SQL row.  Note that multiple calls to
+   * toRow are allowed to return the same actual [[InternalRow]] object.  Thus, the caller should
+   * copy the result before making another call if required.
+   */
+  def toRow(t: T): InternalRow
+}
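
Given the reuse caveat documented on `toRow`, callers that buffer multiple results must copy each row.  A minimal hedged sketch, reusing `SomeInts` from the PR description above:

```scala
import org.apache.spark.sql.catalyst.encoders.{Encoder, ProductEncoder}

case class SomeInts(a: Int, b: Int, c: Int, d: Int, e: Int)  // as in the PR description

val encoder: Encoder[SomeInts] = ProductEncoder[SomeInts]
// toRow may return the same backing InternalRow on every call, so copy() before buffering.
val rows = Seq(SomeInts(1, 2, 3, 4, 5), SomeInts(6, 7, 8, 9, 10)).map(d => encoder.toRow(d).copy())
```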

http://git-wip-us.apache.org/repos/asf/spark/blob/9e66a53c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
new file mode 100644
index 0000000..a236136
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
+
+import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow}
+import org.apache.spark.sql.types.{ObjectType, StructType}
+
+/**
+ * A factory for constructing encoders that convert Scala product types to/from the Spark SQL
+ * internal binary representation.
+ */
+object ProductEncoder {
+  def apply[T <: Product : TypeTag]: Encoder[T] = {
+    // We convert the not-serializable TypeTag into StructType and ClassTag.
+    val schema = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType]
+    val mirror = typeTag[T].mirror
+    val cls = mirror.runtimeClass(typeTag[T].tpe)
+
+    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+    val extractExpressions = ScalaReflection.extractorsFor[T](inputObject)
+    new ClassEncoder[T](schema, extractExpressions, ClassTag[T](cls))
+  }
+}
+
+/**
+ * A generic encoder for JVM objects.
+ *
+ * @param schema The schema after converting `T` to a Spark SQL row.
+ * @param extractExpressions A set of expressions, one for each top-level field, that can be
+ *                           used to extract the field values from a raw object.
+ * @param clsTag A classtag for `T`.
+ */
+case class ClassEncoder[T](
+    schema: StructType,
+    extractExpressions: Seq[Expression],
+    clsTag: ClassTag[T])
+  extends Encoder[T] {
+
+  private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+  private val inputRow = new GenericMutableRow(1)
+
+  override def toRow(t: T): InternalRow = {
+    inputRow(0) = t
+    extractProjection(inputRow)
+  }
+}
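
Note the design: the raw object is placed in slot 0 of a single-column `GenericMutableRow` and the generated `UnsafeProjection` evaluates the extract expressions against that slot.  The factory above boils down to the following hand-rolled construction (a hedged sketch; `Point` is a hypothetical example type):

```scala
import scala.reflect.classTag

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.ClassEncoder
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types.{ObjectType, StructType}

case class Point(x: Int, y: Int)

// What ProductEncoder.apply[Point] does, step by step:
val schema = ScalaReflection.schemaFor[Point].dataType.asInstanceOf[StructType]
val input = BoundReference(0, ObjectType(classOf[Point]), nullable = true)
val extractors = ScalaReflection.extractorsFor[Point](input)
val encoder = ClassEncoder[Point](schema, extractors, classTag[Point])
```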

http://git-wip-us.apache.org/repos/asf/spark/blob/9e66a53c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 2dd6804..a0fe5bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -177,6 +177,8 @@ class CodeGenContext {
     case _: MapType => "MapData"
     case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
     case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
+    case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
+    case ObjectType(cls) => cls.getName
     case _ => "Object"
   }
 
@@ -395,7 +397,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
 
     logDebug({
       // Only add extra debugging info to byte code when we are going to print the source code.
-      evaluator.setDebuggingInformation(false, true, false)
+      evaluator.setDebuggingInformation(true, true, false)
       withLineNums
     })
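
With the new ObjectType cases above, for example, `javaType(ObjectType(classOf[Array[Array[Int]]]))` recurses through the component types and emits the Java source type `int[][]`, while a non-array ObjectType simply emits the class name.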
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9e66a53c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
new file mode 100644
index 0000000..e1f960a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -0,0 +1,334 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.language.existentials
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.types._
+
+/**
+ * Invokes a static function, returning the result.  By default, any of the arguments being null
+ * will result in returning null instead of calling the function.
+ *
+ * @param staticObject The target of the static call.  This can either be the object itself
+ *                     (methods defined on Scala objects), or the class object
+ *                     (static methods defined in Java).
+ * @param dataType The expected return type of the function call.
+ * @param functionName The name of the method to call.
+ * @param arguments An optional list of expressions to pass as arguments to the function.
+ * @param propagateNull When true, if any of the arguments is null, null will be returned instead
+ *                      of calling the function.
+ */
+case class StaticInvoke(
+    staticObject: Any,
+    dataType: DataType,
+    functionName: String,
+    arguments: Seq[Expression] = Nil,
+    propagateNull: Boolean = true) extends Expression {
+
+  val objectName = staticObject match {
+    case c: Class[_] => c.getName
+    case other => other.getClass.getName.stripSuffix("$")
+  }
+  override def nullable: Boolean = true
+  override def children: Seq[Expression] = Nil
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val javaType = ctx.javaType(dataType)
+    val argGen = arguments.map(_.gen(ctx))
+    val argString = argGen.map(_.value).mkString(", ")
+
+    if (propagateNull) {
+      val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
+        s"${ev.isNull} = ${ev.value} == null;"
+      } else {
+        ""
+      }
+
+      val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
+      s"""
+        ${argGen.map(_.code).mkString("\n")}
+
+        boolean ${ev.isNull} = true;
+        $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
+
+        if ($argsNonNull) {
+          ${ev.value} = $objectName.$functionName($argString);
+          $objNullCheck
+        }
+       """
+    } else {
+      s"""
+        ${argGen.map(_.code).mkString("\n")}
+
+        $javaType ${ev.value} = $objectName.$functionName($argString);
+        final boolean ${ev.isNull} = ${ev.value} == null;
+      """
+    }
+  }
+}
+
+/**
+ * Calls the specified function on an object, optionally passing arguments.  If the `targetObject`
+ * expression evaluates to null, then null will be returned.
+ *
+ * @param targetObject An expression that will return the object to call the method on.
+ * @param functionName The name of the method to call.
+ * @param dataType The expected return type of the function.
+ * @param arguments An optional list of expressions whose results will be passed to the function.
+ */
+case class Invoke(
+    targetObject: Expression,
+    functionName: String,
+    dataType: DataType,
+    arguments: Seq[Expression] = Nil) extends Expression {
+
+  override def nullable: Boolean = true
+  override def children: Seq[Expression] = targetObject :: Nil
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val javaType = ctx.javaType(dataType)
+    val obj = targetObject.gen(ctx)
+    val argGen = arguments.map(_.gen(ctx))
+    val argString = argGen.map(_.value).mkString(", ")
+
+    // If the function can return null, we do an extra check to make sure our null bit is still set
+    // correctly.
+    val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
+      s"${ev.isNull} = ${ev.value} == null;"
+    } else {
+      ""
+    }
+
+    s"""
+      ${obj.code}
+      ${argGen.map(_.code).mkString("\n")}
+
+      boolean ${ev.isNull} = ${obj.value} == null;
+      $javaType ${ev.value} =
+        ${ev.isNull} ?
+        ${ctx.defaultValue(dataType)} : ($javaType) ${obj.value}.$functionName($argString);
+      $objNullCheck
+    """
+  }
+}
+
+/**
+ * Constructs a new instance of the given class, using the results of evaluating the specified
+ * expressions as arguments.
+ *
+ * @param cls The class to construct.
+ * @param arguments A list of expressions to use as arguments to the constructor.
+ * @param propagateNull When true, if any of the arguments is null, then null will be returned
+ *                      instead of trying to construct the object.
+ * @param dataType The type of object being constructed, as a Spark SQL datatype.  This allows you
+ *                 to manually specify the type when the object in question is a valid internal
+ *                 representation (i.e. ArrayData) instead of an object.
+ */
+case class NewInstance(
+    cls: Class[_],
+    arguments: Seq[Expression],
+    propagateNull: Boolean = true,
+    dataType: DataType) extends Expression {
+  private val className = cls.getName
+
+  override def nullable: Boolean = propagateNull
+
+  override def children: Seq[Expression] = arguments
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val javaType = ctx.javaType(dataType)
+    val argGen = arguments.map(_.gen(ctx))
+    val argString = argGen.map(_.value).mkString(", ")
+
+    if (propagateNull) {
+      val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
+        s"${ev.isNull} = ${ev.value} == null;"
+      } else {
+        ""
+      }
+
+      val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
+      s"""
+        ${argGen.map(_.code).mkString("\n")}
+
+        boolean ${ev.isNull} = true;
+        $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
+
+        if ($argsNonNull) {
+          ${ev.value} = new $className($argString);
+          ${ev.isNull} = false;
+        }
+       """
+    } else {
+      s"""
+        ${argGen.map(_.code).mkString("\n")}
+
+        $javaType ${ev.value} = new $className($argString);
+        final boolean ${ev.isNull} = ${ev.value} == null;
+      """
+    }
+  }
+}
+
+/**
+ * Given an expression that returns an object of type `Option[_]`, this expression unwraps the
+ * option into the specified Spark SQL datatype.  In the case of `None`, the null bit is set instead.
+ *
+ * @param dataType The expected unwrapped option type.
+ * @param child An expression that returns an `Option`
+ */
+case class UnwrapOption(
+    dataType: DataType,
+    child: Expression) extends UnaryExpression with ExpectsInputTypes {
+
+  override def nullable: Boolean = true
+
+  override def children: Seq[Expression] = Nil
+
+  override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val javaType = ctx.javaType(dataType)
+    val inputObject = child.gen(ctx)
+
+    s"""
+      ${inputObject.code}
+
+      boolean ${ev.isNull} = ${inputObject.value} == null || ${inputObject.value}.isEmpty();
+      $javaType ${ev.value} =
+        ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType)${inputObject.value}.get();
+    """
+  }
+}
+
+case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends Expression {
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
+    throw new UnsupportedOperationException("Only calling gen() is supported.")
+
+  override def children: Seq[Expression] = Nil
+  override def gen(ctx: CodeGenContext): GeneratedExpressionCode =
+    GeneratedExpressionCode(code = "", value = value, isNull = isNull)
+
+  override def nullable: Boolean = false
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+
+}
+
+/**
+ * Applies the given expression to every element of a collection of items, returning the result
+ * as an ArrayType.  This is similar to a typical map operation, but where the lambda function
+ * is expressed using catalyst expressions.
+ *
+ * The following collection ObjectTypes are currently supported: Seq, Array
+ *
+ * @param function A function that returns an expression, given an attribute that can be used
+ *                 to access the current value.  This is done as a lambda function so that
+ *                 a unique attribute reference can be provided for each expression (thus allowing
+ *                 us to nest multiple MapObjects calls).
+ * @param inputData An expression that when evaluated returns a collection object.
+ * @param elementType The type of element in the collection, expressed as a DataType.
+ */
+case class MapObjects(
+    function: AttributeReference => Expression,
+    inputData: Expression,
+    elementType: DataType) extends Expression {
+
+  private val loopAttribute = AttributeReference("loopVar", elementType)()
+  private val completeFunction = function(loopAttribute)
+
+  private val (lengthFunction, itemAccessor) = inputData.dataType match {
+    case ObjectType(cls) if cls.isAssignableFrom(classOf[Seq[_]]) =>
+      (".size()", (i: String) => s".apply($i)")
+    case ObjectType(cls) if cls.isArray =>
+      (".length", (i: String) => s"[$i]")
+  }
+
+  override def nullable: Boolean = true
+
+  override def children: Seq[Expression] = completeFunction :: inputData :: Nil
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def dataType: DataType = ArrayType(completeFunction.dataType)
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val javaType = ctx.javaType(dataType)
+    val elementJavaType = ctx.javaType(elementType)
+    val genInputData = inputData.gen(ctx)
+
+    // Variables to hold the element that is currently being processed.
+    val loopValue = ctx.freshName("loopValue")
+    val loopIsNull = ctx.freshName("loopIsNull")
+
+    val loopVariable = LambdaVariable(loopValue, loopIsNull, elementType)
+    val boundFunction = completeFunction transform {
+      case a: AttributeReference if a == loopAttribute => loopVariable
+    }
+
+    val genFunction = boundFunction.gen(ctx)
+    val dataLength = ctx.freshName("dataLength")
+    val convertedArray = ctx.freshName("convertedArray")
+    val loopIndex = ctx.freshName("loopIndex")
+
+    s"""
+      ${genInputData.code}
+
+      boolean ${ev.isNull} = ${genInputData.value} == null;
+      $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
+
+      if (!${ev.isNull}) {
+        Object[] $convertedArray = null;
+        int $dataLength = ${genInputData.value}$lengthFunction;
+        $convertedArray = new Object[$dataLength];
+
+        int $loopIndex = 0;
+        while ($loopIndex < $dataLength) {
+          $elementJavaType $loopValue =
+            ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
+          boolean $loopIsNull = $loopValue == null;
+
+          ${genFunction.code}
+
+          $convertedArray[$loopIndex] = ${genFunction.value};
+          $loopIndex += 1;
+        }
+
+        ${ev.isNull} = false;
+        ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
+      }
+    """
+  }
+}
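
As a concrete hedged sketch of MapObjects (mirroring the Seq case in ScalaReflection above, and again assuming compilation within the sql package): map each element of a raw `Seq[String]` through a UTF8String conversion.

```scala
import org.apache.spark.sql.catalyst.expressions.{BoundReference, MapObjects, StaticInvoke}
import org.apache.spark.sql.types.{ObjectType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Slot 0 of the input row holds the raw Seq[String].
val seqField = BoundReference(0, ObjectType(classOf[Seq[String]]), nullable = true)

// The lambda receives a fresh attribute bound to the current element and returns
// the per-element conversion expression.
val mapped = MapObjects(
  element => StaticInvoke(classOf[UTF8String], StringType, "fromString", element :: Nil),
  seqField,
  ObjectType(classOf[String]))
```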

http://git-wip-us.apache.org/repos/asf/spark/blob/9e66a53c/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
index 459fcb6..c381603 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -22,6 +22,15 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 class GenericArrayData(private[sql] val array: Array[Any]) extends ArrayData {
 
+  def this(seq: scala.collection.GenIterable[Any]) = this(seq.toArray)
+
+  // TODO: This is boxing.  We should specialize.
+  def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq)
+  def this(primitiveArray: Array[Long]) = this(primitiveArray.toSeq)
+  def this(primitiveArray: Array[Float]) = this(primitiveArray.toSeq)
+  def this(primitiveArray: Array[Double]) = this(primitiveArray.toSeq)
+  def this(primitiveArray: Array[Boolean]) = this(primitiveArray.toSeq)
+
   override def copy(): ArrayData = new GenericArrayData(array.clone())
 
   override def numElements(): Int = array.length
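
A hedged usage sketch of the new constructors (element values end up boxed, per the TODO above):

```scala
import org.apache.spark.sql.types.GenericArrayData

val fromSeq = new GenericArrayData(Seq(1, 2, 3))    // GenIterable overload
val fromInts = new GenericArrayData(Array(1, 2, 3)) // Array[Int] overload; boxes each element
```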

http://git-wip-us.apache.org/repos/asf/spark/blob/9e66a53c/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
new file mode 100644
index 0000000..fca0b79
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.language.existentials
+
+private[sql] object ObjectType extends AbstractDataType {
+  override private[sql] def defaultConcreteType: DataType =
+    throw new UnsupportedOperationException("null literals can't be casted to ObjectType")
+
+  // No casting or comparison is supported.
+  override private[sql] def acceptsType(other: DataType): Boolean = false
+
+  override private[sql] def simpleString: String = "Object"
+}
+
+/**
+ * Represents a JVM object that is passing through Spark SQL expression evaluation.  Note that
+ * this is only used internally while converting into the internal format and is not intended
+ * for use outside of the execution engine.
+ */
+private[sql] case class ObjectType(cls: Class[_]) extends DataType {
+  override def defaultSize: Int =
+    throw new UnsupportedOperationException("No size estimation available for objects.")
+
+  def asNullable: DataType = this
+}
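
A brief hedged illustration (ObjectType is private[sql], so this only compiles from within the sql package): it simply tags a raw JVM class so that object-level expressions such as Invoke and NewInstance can carry it through evaluation.

```scala
import org.apache.spark.sql.types.ObjectType

val rawSeqType = ObjectType(classOf[Seq[String]])  // a raw JVM object flowing through expressions
```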

http://git-wip-us.apache.org/repos/asf/spark/blob/9e66a53c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
new file mode 100644
index 0000000..99c993d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.ScalaReflection._
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst._
+
+
+case class RepeatedStruct(s: Seq[PrimitiveData])
+
+case class NestedArray(a: Array[Array[Int]])
+
+class ProductEncoderSuite extends SparkFunSuite {
+
+  test("convert PrimitiveData to InternalRow") {
+    val inputData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
+    val encoder = ProductEncoder[PrimitiveData]
+    val convertedData = encoder.toRow(inputData)
+
+    assert(convertedData.getInt(0) == 1)
+    assert(convertedData.getLong(1) == 1.toLong)
+    assert(convertedData.getDouble(2) == 1.toDouble)
+    assert(convertedData.getFloat(3) == 1.toFloat)
+    assert(convertedData.getShort(4) == 1.toShort)
+    assert(convertedData.getByte(5) == 1.toByte)
+    assert(convertedData.getBoolean(6) == true)
+  }
+
+  test("convert Some[_] to InternalRow") {
+    val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
+    val inputData = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
+      Some(primitiveData))
+
+    val encoder = ProductEncoder[OptionalData]
+    val convertedData = encoder.toRow(inputData)
+
+    assert(convertedData.getInt(0) == 2)
+    assert(convertedData.getLong(1) == 2.toLong)
+    assert(convertedData.getDouble(2) == 2.toDouble)
+    assert(convertedData.getFloat(3) == 2.toFloat)
+    assert(convertedData.getShort(4) == 2.toShort)
+    assert(convertedData.getByte(5) == 2.toByte)
+    assert(convertedData.getBoolean(6) == true)
+
+    val nestedRow = convertedData.getStruct(7, 7)
+    assert(nestedRow.getInt(0) == 1)
+    assert(nestedRow.getLong(1) == 1.toLong)
+    assert(nestedRow.getDouble(2) == 1.toDouble)
+    assert(nestedRow.getFloat(3) == 1.toFloat)
+    assert(nestedRow.getShort(4) == 1.toShort)
+    assert(nestedRow.getByte(5) == 1.toByte)
+    assert(nestedRow.getBoolean(6) == true)
+  }
+
+  test("convert None to InternalRow") {
+    val inputData = OptionalData(None, None, None, None, None, None, None, None)
+    val encoder = ProductEncoder[OptionalData]
+    val convertedData = encoder.toRow(inputData)
+
+    assert(convertedData.isNullAt(0))
+    assert(convertedData.isNullAt(1))
+    assert(convertedData.isNullAt(2))
+    assert(convertedData.isNullAt(3))
+    assert(convertedData.isNullAt(4))
+    assert(convertedData.isNullAt(5))
+    assert(convertedData.isNullAt(6))
+    assert(convertedData.isNullAt(7))
+  }
+
+  test("convert nullable but present data to InternalRow") {
+    val inputData = NullableData(
+      1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true, "test", new java.math.BigDecimal(1), new Date(0),
+      new Timestamp(0), Array[Byte](1, 2, 3))
+
+    val encoder = ProductEncoder[NullableData]
+    val convertedData = encoder.toRow(inputData)
+
+    assert(convertedData.getInt(0) == 1)
+    assert(convertedData.getLong(1) == 1.toLong)
+    assert(convertedData.getDouble(2) == 1.toDouble)
+    assert(convertedData.getFloat(3) == 1.toFloat)
+    assert(convertedData.getShort(4) == 1.toShort)
+    assert(convertedData.getByte(5) == 1.toByte)
+    assert(convertedData.getBoolean(6) == true)
+  }
+
+  test("convert nullable data to InternalRow") {
+    val inputData =
+      NullableData(null, null, null, null, null, null, null, null, null, null, null, null)
+
+    val encoder = ProductEncoder[NullableData]
+    val convertedData = encoder.toRow(inputData)
+
+    assert(convertedData.isNullAt(0))
+    assert(convertedData.isNullAt(1))
+    assert(convertedData.isNullAt(2))
+    assert(convertedData.isNullAt(3))
+    assert(convertedData.isNullAt(4))
+    assert(convertedData.isNullAt(5))
+    assert(convertedData.isNullAt(6))
+    assert(convertedData.isNullAt(7))
+    assert(convertedData.isNullAt(8))
+    assert(convertedData.isNullAt(9))
+    assert(convertedData.isNullAt(10))
+    assert(convertedData.isNullAt(11))
+  }
+
+  test("convert repeated struct") {
+    val inputData = RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)
+    val encoder = ProductEncoder[RepeatedStruct]
+
+    val converted = encoder.toRow(inputData)
+    val convertedStruct = converted.getArray(0).getStruct(0, 7)
+    assert(convertedStruct.getInt(0) == 1)
+    assert(convertedStruct.getLong(1) == 1.toLong)
+    assert(convertedStruct.getDouble(2) == 1.toDouble)
+    assert(convertedStruct.getFloat(3) == 1.toFloat)
+    assert(convertedStruct.getShort(4) == 1.toShort)
+    assert(convertedStruct.getByte(5) == 1.toByte)
+    assert(convertedStruct.getBoolean(6) == true)
+  }
+
+  test("convert nested seq") {
+    val convertedData = ProductEncoder[Tuple1[Seq[Seq[Int]]]].toRow(Tuple1(Seq(Seq(1))))
+    assert(convertedData.getArray(0).getArray(0).getInt(0) == 1)
+
+    val convertedData2 = ProductEncoder[Tuple1[Seq[Seq[Seq[Int]]]]].toRow(Tuple1(Seq(Seq(Seq(1)))))
+    assert(convertedData2.getArray(0).getArray(0).getArray(0).getInt(0) == 1)
+  }
+
+  test("convert nested array") {
+    val convertedData = ProductEncoder[Tuple1[Array[Array[Int]]]].toRow(Tuple1(Array(Array(1))))
+  }
+
+  test("convert complex") {
+    val inputData = ComplexData(
+      Seq(1, 2),
+      Array(1, 2),
+      1 :: 2 :: Nil,
+      Seq(new Integer(1), null, new Integer(2)),
+      Map(1 -> 2L),
+      Map(1 -> new java.lang.Long(2)),
+      PrimitiveData(1, 1, 1, 1, 1, 1, true),
+      Array(Array(1)))
+
+    val encoder = ProductEncoder[ComplexData]
+    val convertedData = encoder.toRow(inputData)
+
+    assert(!convertedData.isNullAt(0))
+    val seq = convertedData.getArray(0)
+    assert(seq.numElements() == 2)
+    assert(seq.getInt(0) == 1)
+    assert(seq.getInt(1) == 2)
+  }
+}

