You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2015/12/10 08:11:21 UTC

spark git commit: [SPARK-12252][SPARK-12131][SQL] refactor MapObjects to make it less hacky

Repository: spark
Updated Branches:
  refs/heads/master bd2cd4f53 -> d8ec081c9


[SPARK-12252][SPARK-12131][SQL] refactor MapObjects to make it less hacky

In https://github.com/apache/spark/pull/10133 we found that we should ensure the children of `TreeNode` are all accessible in the `productIterator`; otherwise the behavior will be very confusing.

In this PR, I try to fix this problem by exposing the `loopVar`.

This also fixes SPARK-12131 which is caused by the hacky `MapObjects`.

Author: Wenchen Fan <we...@databricks.com>

Closes #10239 from cloud-fan/map-objects.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d8ec081c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d8ec081c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d8ec081c

Branch: refs/heads/master
Commit: d8ec081c911a040f3fb523a68025928ae4afc906
Parents: bd2cd4f
Author: Wenchen Fan <we...@databricks.com>
Authored: Thu Dec 10 15:11:13 2015 +0800
Committer: Cheng Lian <li...@databricks.com>
Committed: Thu Dec 10 15:11:13 2015 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    |  4 --
 .../sql/catalyst/encoders/RowEncoder.scala      |  2 +-
 .../sql/catalyst/expressions/objects.scala      | 75 +++++++++-----------
 .../encoders/ExpressionEncoderSuite.scala       |  1 +
 4 files changed, 35 insertions(+), 47 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d8ec081c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 9b6b5b8..9013fd0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -414,10 +414,6 @@ object ScalaReflection extends ScalaReflection {
       } else {
         val clsName = getClassNameFromType(elementType)
         val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
-        // `MapObjects` will run `extractorFor` lazily, we need to eagerly call `extractorFor` here
-        // to trigger the type check.
-        extractorFor(inputObject, elementType, newPath)
-
         MapObjects(extractorFor(_, elementType, newPath), input, externalDataType)
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/d8ec081c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 67518f5..d34ec94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -193,7 +193,7 @@ object RowEncoder {
     case ArrayType(et, nullable) =>
       val arrayData =
         Invoke(
-          MapObjects(constructorFor, input, et),
+          MapObjects(constructorFor(_), input, et),
           "array",
           ObjectType(classOf[Array[_]]))
       StaticInvoke(

http://git-wip-us.apache.org/repos/asf/spark/blob/d8ec081c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index e6ab9a3..b2facfd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -326,19 +326,28 @@ case class WrapOption(child: Expression)
  * A place holder for the loop variable used in [[MapObjects]].  This should never be constructed
  * manually, but will instead be passed into the provided lambda function.
  */
-case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends Expression {
+case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression
+  with Unevaluable {
 
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
-    throw new UnsupportedOperationException("Only calling gen() is supported.")
+  override def nullable: Boolean = true
 
-  override def children: Seq[Expression] = Nil
-  override def gen(ctx: CodeGenContext): GeneratedExpressionCode =
+  override def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
     GeneratedExpressionCode(code = "", value = value, isNull = isNull)
+  }
+}
 
-  override def nullable: Boolean = false
-  override def eval(input: InternalRow): Any =
-    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+object MapObjects {
+  private val curId = new java.util.concurrent.atomic.AtomicInteger()
 
+  def apply(
+      function: Expression => Expression,
+      inputData: Expression,
+      elementType: DataType): MapObjects = {
+    val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
+    val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
+    val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
+    MapObjects(loopVar, function(loopVar), inputData)
+  }
 }
 
 /**
@@ -349,20 +358,16 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext
  * The following collection ObjectTypes are currently supported:
  *   Seq, Array, ArrayData, java.util.List
  *
- * @param function A function that returns an expression, given an attribute that can be used
- *                 to access the current value.  This is does as a lambda function so that
- *                 a unique attribute reference can be provided for each expression (thus allowing
- *                 us to nest multiple MapObject calls).
+ * @param loopVar A placeholder used as the loop variable when iterating over the collection, and
+ *                used as input for the `lambdaFunction`. It also carries the element type info.
+ * @param lambdaFunction A function that takes the `loopVar` as input and is used as the lambda
+ *                       function to handle collection elements.
  * @param inputData An expression that when evaluted returns a collection object.
- * @param elementType The type of element in the collection, expressed as a DataType.
  */
 case class MapObjects(
-    function: AttributeReference => Expression,
-    inputData: Expression,
-    elementType: DataType) extends Expression {
-
-  private lazy val loopAttribute = AttributeReference("loopVar", elementType)()
-  private lazy val completeFunction = function(loopAttribute)
+    loopVar: LambdaVariable,
+    lambdaFunction: Expression,
+    inputData: Expression) extends Expression {
 
   private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
     case NullType =>
@@ -402,37 +407,23 @@ case class MapObjects(
 
   override def nullable: Boolean = true
 
-  override def children: Seq[Expression] = completeFunction :: inputData :: Nil
+  override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
 
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")
 
-  override def dataType: DataType = ArrayType(completeFunction.dataType)
+  override def dataType: DataType = ArrayType(lambdaFunction.dataType)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val javaType = ctx.javaType(dataType)
-    val elementJavaType = ctx.javaType(elementType)
+    val elementJavaType = ctx.javaType(loopVar.dataType)
     val genInputData = inputData.gen(ctx)
-
-    // Variables to hold the element that is currently being processed.
-    val loopValue = ctx.freshName("loopValue")
-    val loopIsNull = ctx.freshName("loopIsNull")
-
-    val loopVariable = LambdaVariable(loopValue, loopIsNull, elementType)
-    val substitutedFunction = completeFunction transform {
-      case a: AttributeReference if a == loopAttribute => loopVariable
-    }
-    // A hack to run this through the analyzer (to bind extractions).
-    val boundFunction =
-      SimpleAnalyzer.execute(Project(Alias(substitutedFunction, "")() :: Nil, LocalRelation(Nil)))
-        .expressions.head.children.head
-
-    val genFunction = boundFunction.gen(ctx)
+    val genFunction = lambdaFunction.gen(ctx)
     val dataLength = ctx.freshName("dataLength")
     val convertedArray = ctx.freshName("convertedArray")
     val loopIndex = ctx.freshName("loopIndex")
 
-    val convertedType = ctx.boxedType(boundFunction.dataType)
+    val convertedType = ctx.boxedType(lambdaFunction.dataType)
 
     // Because of the way Java defines nested arrays, we have to handle the syntax specially.
     // Specifically, we have to insert the [$dataLength] in between the type and any extra nested
@@ -446,9 +437,9 @@ case class MapObjects(
     }
 
     val loopNullCheck = if (primitiveElement) {
-      s"boolean $loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
+      s"boolean ${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
     } else {
-      s"boolean $loopIsNull = ${genInputData.isNull} || $loopValue == null;"
+      s"boolean ${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
     }
 
     s"""
@@ -464,11 +455,11 @@ case class MapObjects(
 
         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {
-          $elementJavaType $loopValue =
+          $elementJavaType ${loopVar.value} =
             ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
           $loopNullCheck
 
-          if ($loopIsNull) {
+          if (${loopVar.isNull}) {
             $convertedArray[$loopIndex] = null;
           } else {
             ${genFunction.code}

http://git-wip-us.apache.org/repos/asf/spark/blob/d8ec081c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index d6ca138..7233e0f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -145,6 +145,7 @@ class ExpressionEncoderSuite extends SparkFunSuite {
 
   case class InnerClass(i: Int)
   productTest(InnerClass(1))
+  encodeDecodeTest(Array(InnerClass(1)), "array of inner class")
 
   productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org