You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/06/30 16:59:21 UTC

spark git commit: [SPARK-8590] [SQL] add code gen for ExtractValue

Repository: spark
Updated Branches:
  refs/heads/master 2ed0c0ac4 -> 08fab4843


[SPARK-8590] [SQL] add code gen for ExtractValue

TODO:  use array instead of Seq as internal representation for `ArrayType`

Author: Wenchen Fan <cl...@outlook.com>

Closes #6982 from cloud-fan/extract-value and squashes the following commits:

e203bc1 [Wenchen Fan] address comments
4da0f0b [Wenchen Fan] some clean up
f679969 [Wenchen Fan] fix bug
e64f942 [Wenchen Fan] remove generic
e3f8427 [Wenchen Fan] fix style and address comments
fc694e8 [Wenchen Fan] add code gen for extract value


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/08fab484
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/08fab484
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/08fab484

Branch: refs/heads/master
Commit: 08fab4843845136358f3a7251e8d90135126b419
Parents: 2ed0c0a
Author: Wenchen Fan <cl...@outlook.com>
Authored: Tue Jun 30 07:58:49 2015 -0700
Committer: Davies Liu <da...@databricks.com>
Committed: Tue Jun 30 07:58:49 2015 -0700

----------------------------------------------------------------------
 .../catalyst/expressions/BoundAttribute.scala   |   2 +-
 .../sql/catalyst/expressions/Expression.scala   |  46 +++++--
 .../sql/catalyst/expressions/ExtractValue.scala |  76 +++++++++--
 .../sql/catalyst/expressions/arithmetic.scala   |   6 +-
 .../expressions/codegen/CodeGenerator.scala     |  15 ++-
 .../codegen/GenerateMutableProjection.scala     |   2 +-
 .../spark/sql/catalyst/expressions/math.scala   |  13 +-
 .../sql/catalyst/expressions/predicates.scala   |   3 -
 .../spark/sql/catalyst/expressions/sets.scala   |   4 -
 .../spark/sql/catalyst/util/TypeUtils.scala     |   2 +-
 .../catalyst/expressions/ComplexTypeSuite.scala | 131 +++++++++++--------
 11 files changed, 199 insertions(+), 101 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/08fab484/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 5db2fcf..dc0b4ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -47,7 +47,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
     s"""
         boolean ${ev.isNull} = i.isNullAt($ordinal);
         ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
-            ${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
+            ${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)});
     """
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/08fab484/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index e5dc7b9..aed4892 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -179,9 +179,10 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
   override def toString: String = s"($left $symbol $right)"
 
   override def isThreadSafe: Boolean = left.isThreadSafe && right.isThreadSafe
+
   /**
-   * Short hand for generating binary evaluation code, which depends on two sub-evaluations of
-   * the same type.  If either of the sub-expressions is null, the result of this computation
+   * Short hand for generating binary evaluation code.
+   * If either of the sub-expressions is null, the result of this computation
    * is assumed to be null.
    *
    * @param f accepts two variable names and returns Java code to compute the output.
@@ -190,15 +191,23 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
       ctx: CodeGenContext,
       ev: GeneratedExpressionCode,
       f: (String, String) => String): String = {
-    // TODO: Right now some timestamp tests fail if we enforce this...
-    if (left.dataType != right.dataType) {
-      // log.warn(s"${left.dataType} != ${right.dataType}")
-    }
+    nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
+      s"$result = ${f(eval1, eval2)};"
+    })
+  }
 
+  /**
+   * Short hand for generating binary evaluation code.
+   * If either of the sub-expressions is null, the result of this computation
+   * is assumed to be null.
+   */
+  protected def nullSafeCodeGen(
+      ctx: CodeGenContext,
+      ev: GeneratedExpressionCode,
+      f: (String, String, String) => String): String = {
     val eval1 = left.gen(ctx)
     val eval2 = right.gen(ctx)
-    val resultCode = f(eval1.primitive, eval2.primitive)
-
+    val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive)
     s"""
       ${eval1.code}
       boolean ${ev.isNull} = ${eval1.isNull};
@@ -206,7 +215,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
       if (!${ev.isNull}) {
         ${eval2.code}
         if (!${eval2.isNull}) {
-          ${ev.primitive} = $resultCode;
+          $resultCode
         } else {
           ${ev.isNull} = true;
         }
@@ -245,13 +254,26 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
       ctx: CodeGenContext,
       ev: GeneratedExpressionCode,
       f: String => String): String = {
+    nullSafeCodeGen(ctx, ev, (result, eval) => {
+      s"$result = ${f(eval)};"
+    })
+  }
+
+  /**
+   * Called by unary expressions to generate a code block that returns null if its parent returns
+   * null, and if not null, use `f` to generate the expression.
+   */
+  protected def nullSafeCodeGen(
+      ctx: CodeGenContext,
+      ev: GeneratedExpressionCode,
+      f: (String, String) => String): String = {
     val eval = child.gen(ctx)
-    // reuse the previous isNull
-    ev.isNull = eval.isNull
+    val resultCode = f(ev.primitive, eval.primitive)
     eval.code + s"""
+      boolean ${ev.isNull} = ${eval.isNull};
       ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
       if (!${ev.isNull}) {
-        ${ev.primitive} = ${f(eval.primitive)};
+        $resultCode
       }
     """
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/08fab484/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
index 4d7c95f..3020e7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -21,6 +21,7 @@ import scala.collection.Map
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.types._
 
 object ExtractValue {
@@ -38,7 +39,7 @@ object ExtractValue {
   def apply(
       child: Expression,
       extraction: Expression,
-      resolver: Resolver): ExtractValue = {
+      resolver: Resolver): Expression = {
 
     (child.dataType, extraction) match {
       case (StructType(fields), NonNullLiteral(v, StringType)) =>
@@ -73,7 +74,7 @@ object ExtractValue {
   def unapply(g: ExtractValue): Option[(Expression, Expression)] = {
     g match {
       case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal))
-      case _ => Some((g.child, null))
+      case s: ExtractValueWithStruct => Some((s.child, null))
     }
   }
 
@@ -101,11 +102,11 @@ object ExtractValue {
  * Note: concrete extract value expressions are created only by `ExtractValue.apply`,
  * we don't need to do type check for them.
  */
-trait ExtractValue extends UnaryExpression {
-  self: Product =>
+trait ExtractValue {
+  self: Expression =>
 }
 
-abstract class ExtractValueWithStruct extends ExtractValue {
+abstract class ExtractValueWithStruct extends UnaryExpression with ExtractValue {
   self: Product =>
 
   def field: StructField
@@ -125,6 +126,18 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int)
     val baseValue = child.eval(input).asInstanceOf[InternalRow]
     if (baseValue == null) null else baseValue(ordinal)
   }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (result, eval) => {
+      s"""
+        if ($eval.isNullAt($ordinal)) {
+          ${ev.isNull} = true;
+        } else {
+          $result = ${ctx.getColumn(eval, dataType, ordinal)};
+        }
+      """
+    })
+  }
 }
 
 /**
@@ -137,6 +150,7 @@ case class GetArrayStructFields(
     containsNull: Boolean) extends ExtractValueWithStruct {
 
   override def dataType: DataType = ArrayType(field.dataType, containsNull)
+  override def nullable: Boolean = child.nullable || containsNull || field.nullable
 
   override def eval(input: InternalRow): Any = {
     val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]
@@ -146,18 +160,39 @@ case class GetArrayStructFields(
       }
     }
   }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val arraySeqClass = "scala.collection.mutable.ArraySeq"
+    // TODO: consider using Array[_] for ArrayType child to avoid
+    // boxing of primitives
+    nullSafeCodeGen(ctx, ev, (result, eval) => {
+      s"""
+        final int n = $eval.size();
+        final $arraySeqClass<Object> values = new $arraySeqClass<Object>(n);
+        for (int j = 0; j < n; j++) {
+          InternalRow row = (InternalRow) $eval.apply(j);
+          if (row != null && !row.isNullAt($ordinal)) {
+            values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)});
+          }
+        }
+        $result = (${ctx.javaType(dataType)}) values;
+      """
+    })
+  }
 }
 
-abstract class ExtractValueWithOrdinal extends ExtractValue {
+abstract class ExtractValueWithOrdinal extends BinaryExpression with ExtractValue {
   self: Product =>
 
   def ordinal: Expression
+  def child: Expression
+
+  override def left: Expression = child
+  override def right: Expression = ordinal
 
   /** `Null` is returned for invalid ordinals. */
   override def nullable: Boolean = true
-  override def foldable: Boolean = child.foldable && ordinal.foldable
   override def toString: String = s"$child[$ordinal]"
-  override def children: Seq[Expression] = child :: ordinal :: Nil
 
   override def eval(input: InternalRow): Any = {
     val value = child.eval(input)
@@ -195,6 +230,19 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
       baseValue(index)
     }
   }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
+      s"""
+        final int index = (int)$eval2;
+        if (index >= $eval1.size() || index < 0) {
+          ${ev.isNull} = true;
+        } else {
+          $result = (${ctx.boxedType(dataType)})$eval1.apply(index);
+        }
+      """
+    })
+  }
 }
 
 /**
@@ -209,4 +257,16 @@ case class GetMapValue(child: Expression, ordinal: Expression)
     val baseValue = value.asInstanceOf[Map[Any, _]]
     baseValue.get(ordinal).orNull
   }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
+      s"""
+        if ($eval1.contains($eval2)) {
+          $result = (${ctx.boxedType(dataType)})$eval1.apply($eval2);
+        } else {
+          ${ev.isNull} = true;
+        }
+      """
+    })
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/08fab484/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 3d4d9e2..ae765c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -82,8 +82,6 @@ case class Abs(child: Expression) extends UnaryArithmetic {
 abstract class BinaryArithmetic extends BinaryExpression {
   self: Product =>
 
-  /** Name of the function for this expression on a [[Decimal]] type. */
-  def decimalMethod: String = ""
 
   override def dataType: DataType = left.dataType
 
@@ -113,6 +111,10 @@ abstract class BinaryArithmetic extends BinaryExpression {
     }
   }
 
+  /** Name of the function for this expression on a [[Decimal]] type. */
+  def decimalMethod: String =
+    sys.error("BinaryArithmetics must override either decimalMethod or genCode")
+
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
     case dt: DecimalType =>
       defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")

http://git-wip-us.apache.org/repos/asf/spark/blob/08fab484/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 57e0bed..bf6a6a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -82,24 +82,24 @@ class CodeGenContext {
   /**
    * Returns the code to access a column in Row for a given DataType.
    */
-  def getColumn(dataType: DataType, ordinal: Int): String = {
+  def getColumn(row: String, dataType: DataType, ordinal: Int): String = {
     val jt = javaType(dataType)
     if (isPrimitiveType(jt)) {
-      s"i.get${primitiveTypeName(jt)}($ordinal)"
+      s"$row.get${primitiveTypeName(jt)}($ordinal)"
     } else {
-      s"($jt)i.apply($ordinal)"
+      s"($jt)$row.apply($ordinal)"
     }
   }
 
   /**
    * Returns the code to update a column in Row for a given DataType.
    */
-  def setColumn(dataType: DataType, ordinal: Int, value: String): String = {
+  def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = {
     val jt = javaType(dataType)
     if (isPrimitiveType(jt)) {
-      s"set${primitiveTypeName(jt)}($ordinal, $value)"
+      s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
     } else {
-      s"update($ordinal, $value)"
+      s"$row.update($ordinal, $value)"
     }
   }
 
@@ -127,6 +127,9 @@ class CodeGenContext {
     case dt: DecimalType => decimalType
     case BinaryType => "byte[]"
     case StringType => stringType
+    case _: StructType => "InternalRow"
+    case _: ArrayType => s"scala.collection.Seq"
+    case _: MapType => s"scala.collection.Map"
     case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
     case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
     case _ => "Object"

http://git-wip-us.apache.org/repos/asf/spark/blob/08fab484/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 64ef357..addb802 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -43,7 +43,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
           if(${evaluationCode.isNull})
             mutableRow.setNullAt($i);
           else
-            mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)};
+            ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
         """
     }.mkString("\n")
     val code = s"""

http://git-wip-us.apache.org/repos/asf/spark/blob/08fab484/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index a022f37..da63f2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -78,17 +78,14 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
   def funcName: String = name.toLowerCase
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val eval = child.gen(ctx)
-    eval.code + s"""
-      boolean ${ev.isNull} = ${eval.isNull};
-      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive});
+    nullSafeCodeGen(ctx, ev, (result, eval) => {
+      s"""
+        ${ev.primitive} = java.lang.Math.${funcName}($eval);
         if (Double.valueOf(${ev.primitive}).isNaN()) {
           ${ev.isNull} = true;
         }
-      }
-    """
+      """
+    })
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/08fab484/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 386cf6a..98cd5aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -69,10 +69,7 @@ trait PredicateHelper {
     expr.references.subsetOf(plan.outputSet)
 }
 
-
 case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes {
-  override def foldable: Boolean = child.foldable
-  override def nullable: Boolean = child.nullable
   override def toString: String = s"NOT $child"
 
   override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)

http://git-wip-us.apache.org/repos/asf/spark/blob/08fab484/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index efc6f50..daa9f44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -135,8 +135,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
  */
 case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
 
-  override def nullable: Boolean = left.nullable || right.nullable
-
   override def dataType: DataType = left.dataType
 
   override def symbol: String = "++="
@@ -185,8 +183,6 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
  */
 case class CountSet(child: Expression) extends UnaryExpression {
 
-  override def nullable: Boolean = child.nullable
-
   override def dataType: DataType = LongType
 
   override def eval(input: InternalRow): Any = {

http://git-wip-us.apache.org/repos/asf/spark/blob/08fab484/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 8656cc3..3148309 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.types._
 
 /**
- * Helper function to check for valid data types
+ * Helper functions to check for valid data types.
  */
 object TypeUtils {
   def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = {

http://git-wip-us.apache.org/repos/asf/spark/blob/08fab484/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index b80911e..3515d04 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -40,51 +40,42 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("GetArrayItem") {
+    val typeA = ArrayType(StringType)
+    val array = Literal.create(Seq("a", "b"), typeA)
     testIntegralDataTypes { convert =>
-      val array = Literal.create(Seq("a", "b"), ArrayType(StringType))
       checkEvaluation(GetArrayItem(array, Literal(convert(1))), "b")
     }
+    val nullArray = Literal.create(null, typeA)
+    val nullInt = Literal.create(null, IntegerType)
+    checkEvaluation(GetArrayItem(nullArray, Literal(1)), null)
+    checkEvaluation(GetArrayItem(array, nullInt), null)
+    checkEvaluation(GetArrayItem(nullArray, nullInt), null)
+
+    val nestedArray = Literal.create(Seq(Seq(1)), ArrayType(ArrayType(IntegerType)))
+    checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1))
   }
 
-  test("CreateStruct") {
-    val row = InternalRow(1, 2, 3)
-    val c1 = 'a.int.at(0).as("a")
-    val c3 = 'c.int.at(2).as("c")
-    checkEvaluation(CreateStruct(Seq(c1, c3)), InternalRow(1, 3), row)
+  test("GetMapValue") {
+    val typeM = MapType(StringType, StringType)
+    val map = Literal.create(Map("a" -> "b"), typeM)
+    val nullMap = Literal.create(null, typeM)
+    val nullString = Literal.create(null, StringType)
+
+    checkEvaluation(GetMapValue(map, Literal("a")), "b")
+    checkEvaluation(GetMapValue(map, nullString), null)
+    checkEvaluation(GetMapValue(nullMap, nullString), null)
+    checkEvaluation(GetMapValue(map, nullString), null)
+
+    val nestedMap = Literal.create(Map("a" -> Map("b" -> "c")), MapType(StringType, typeM))
+    checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c"))
   }
 
-  test("complex type") {
-    val row = create_row(
-      "^Ba*n",                                // 0
-      null.asInstanceOf[UTF8String],          // 1
-      create_row("aa", "bb"),                 // 2
-      Map("aa"->"bb"),                        // 3
-      Seq("aa", "bb")                         // 4
-    )
-
-    val typeS = StructType(
-      StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil
-    )
-    val typeMap = MapType(StringType, StringType)
-    val typeArray = ArrayType(StringType)
-
-    checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
-      Literal("aa")), "bb", row)
-    checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row)
-    checkEvaluation(
-      GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row)
-    checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
-      Literal.create(null, StringType)), null, row)
-
-    checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
-      Literal(1)), "bb", row)
-    checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row)
-    checkEvaluation(
-      GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row)
-    checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
-      Literal.create(null, IntegerType)), null, row)
-
-    def getStructField(expr: Expression, fieldName: String): ExtractValue = {
+  test("GetStructField") {
+    val typeS = StructType(StructField("a", IntegerType) :: Nil)
+    val struct = Literal.create(create_row(1), typeS)
+    val nullStruct = Literal.create(null, typeS)
+
+    def getStructField(expr: Expression, fieldName: String): GetStructField = {
       expr.dataType match {
         case StructType(fields) =>
           val field = fields.find(_.name == fieldName).get
@@ -92,28 +83,58 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
       }
     }
 
-    def quickResolve(u: UnresolvedExtractValue): ExtractValue = {
-      ExtractValue(u.child, u.extraction, _ == _)
-    }
+    checkEvaluation(getStructField(struct, "a"), 1)
+    checkEvaluation(getStructField(nullStruct, "a"), null)
+
+    val nestedStruct = Literal.create(create_row(create_row(1)),
+      StructType(StructField("a", typeS) :: Nil))
+    checkEvaluation(getStructField(nestedStruct, "a"), create_row(1))
+
+    val typeS_fieldNotNullable = StructType(StructField("a", IntegerType, false) :: Nil)
+    val struct_fieldNotNullable = Literal.create(create_row(1), typeS_fieldNotNullable)
+    val nullStruct_fieldNotNullable = Literal.create(null, typeS_fieldNotNullable)
+
+    assert(getStructField(struct_fieldNotNullable, "a").nullable === false)
+    assert(getStructField(struct, "a").nullable === true)
+    assert(getStructField(nullStruct_fieldNotNullable, "a").nullable === true)
+    assert(getStructField(nullStruct, "a").nullable === true)
+  }
 
-    checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
-    checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row)
+  test("GetArrayStructFields") {
+    val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+    val arrayStruct = Literal.create(Seq(create_row(1)), typeAS)
+    val nullArrayStruct = Literal.create(null, typeAS)
 
-    val typeS_notNullable = StructType(
-      StructField("a", StringType, nullable = false)
-        :: StructField("b", StringType, nullable = false) :: Nil
-    )
+    def getArrayStructFields(expr: Expression, fieldName: String): GetArrayStructFields = {
+      expr.dataType match {
+        case ArrayType(StructType(fields), containsNull) =>
+          val field = fields.find(_.name == fieldName).get
+          GetArrayStructFields(expr, field, fields.indexOf(field), containsNull)
+      }
+    }
+
+    checkEvaluation(getArrayStructFields(arrayStruct, "a"), Seq(1))
+    checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null)
+  }
 
-    assert(getStructField(BoundReference(2, typeS, nullable = true), "a").nullable === true)
-    assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable
-      === false)
+  test("CreateStruct") {
+    val row = create_row(1, 2, 3)
+    val c1 = 'a.int.at(0).as("a")
+    val c3 = 'c.int.at(2).as("c")
+    checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row)
+  }
 
-    assert(getStructField(Literal.create(null, typeS), "a").nullable === true)
-    assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true)
+  test("test dsl for complex type") {
+    def quickResolve(u: UnresolvedExtractValue): Expression = {
+      ExtractValue(u.child, u.extraction, _ == _)
+    }
 
-    checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row)
-    checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row)
-    checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row)
+    checkEvaluation(quickResolve('c.map(MapType(StringType, StringType)).at(0).getItem("a")),
+      "b", create_row(Map("a" -> "b")))
+    checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)),
+      "b", create_row(Seq("a", "b")))
+    checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")),
+      1, create_row(create_row(1)))
   }
 
   test("error message of ExtractValue") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org