You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2015/12/10 08:11:21 UTC
spark git commit: [SPARK-12252][SPARK-12131][SQL] refactor MapObjects
to make it less hacky
Repository: spark
Updated Branches:
refs/heads/master bd2cd4f53 -> d8ec081c9
[SPARK-12252][SPARK-12131][SQL] refactor MapObjects to make it less hacky
In https://github.com/apache/spark/pull/10133 we found that we should ensure the children of `TreeNode` are all accessible in the `productIterator`, or the behavior will be very confusing.
In this PR, I try to fix this problem by exposing the `loopVar`.
This also fixes SPARK-12131 which is caused by the hacky `MapObjects`.
Author: Wenchen Fan <we...@databricks.com>
Closes #10239 from cloud-fan/map-objects.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d8ec081c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d8ec081c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d8ec081c
Branch: refs/heads/master
Commit: d8ec081c911a040f3fb523a68025928ae4afc906
Parents: bd2cd4f
Author: Wenchen Fan <we...@databricks.com>
Authored: Thu Dec 10 15:11:13 2015 +0800
Committer: Cheng Lian <li...@databricks.com>
Committed: Thu Dec 10 15:11:13 2015 +0800
----------------------------------------------------------------------
.../spark/sql/catalyst/ScalaReflection.scala | 4 --
.../sql/catalyst/encoders/RowEncoder.scala | 2 +-
.../sql/catalyst/expressions/objects.scala | 75 +++++++++-----------
.../encoders/ExpressionEncoderSuite.scala | 1 +
4 files changed, 35 insertions(+), 47 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/d8ec081c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 9b6b5b8..9013fd0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -414,10 +414,6 @@ object ScalaReflection extends ScalaReflection {
} else {
val clsName = getClassNameFromType(elementType)
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
- // `MapObjects` will run `extractorFor` lazily, we need to eagerly call `extractorFor` here
- // to trigger the type check.
- extractorFor(inputObject, elementType, newPath)
-
MapObjects(extractorFor(_, elementType, newPath), input, externalDataType)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/d8ec081c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 67518f5..d34ec94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -193,7 +193,7 @@ object RowEncoder {
case ArrayType(et, nullable) =>
val arrayData =
Invoke(
- MapObjects(constructorFor, input, et),
+ MapObjects(constructorFor(_), input, et),
"array",
ObjectType(classOf[Array[_]]))
StaticInvoke(
http://git-wip-us.apache.org/repos/asf/spark/blob/d8ec081c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index e6ab9a3..b2facfd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -326,19 +326,28 @@ case class WrapOption(child: Expression)
* A place holder for the loop variable used in [[MapObjects]]. This should never be constructed
* manually, but will instead be passed into the provided lambda function.
*/
-case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends Expression {
+case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression
+ with Unevaluable {
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
- throw new UnsupportedOperationException("Only calling gen() is supported.")
+ override def nullable: Boolean = true
- override def children: Seq[Expression] = Nil
- override def gen(ctx: CodeGenContext): GeneratedExpressionCode =
+ override def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
GeneratedExpressionCode(code = "", value = value, isNull = isNull)
+ }
+}
- override def nullable: Boolean = false
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+object MapObjects {
+ private val curId = new java.util.concurrent.atomic.AtomicInteger()
+ def apply(
+ function: Expression => Expression,
+ inputData: Expression,
+ elementType: DataType): MapObjects = {
+ val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
+ val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
+ val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
+ MapObjects(loopVar, function(loopVar), inputData)
+ }
}
/**
@@ -349,20 +358,16 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext
* The following collection ObjectTypes are currently supported:
* Seq, Array, ArrayData, java.util.List
*
- * @param function A function that returns an expression, given an attribute that can be used
- * to access the current value. This is does as a lambda function so that
- * a unique attribute reference can be provided for each expression (thus allowing
- * us to nest multiple MapObject calls).
+ * @param loopVar A place holder that used as the loop variable when iterate the collection, and
+ * used as input for the `lambdaFunction`. It also carries the element type info.
+ * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
+ * to handle collection elements.
* @param inputData An expression that when evaluted returns a collection object.
- * @param elementType The type of element in the collection, expressed as a DataType.
*/
case class MapObjects(
- function: AttributeReference => Expression,
- inputData: Expression,
- elementType: DataType) extends Expression {
-
- private lazy val loopAttribute = AttributeReference("loopVar", elementType)()
- private lazy val completeFunction = function(loopAttribute)
+ loopVar: LambdaVariable,
+ lambdaFunction: Expression,
+ inputData: Expression) extends Expression {
private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
case NullType =>
@@ -402,37 +407,23 @@ case class MapObjects(
override def nullable: Boolean = true
- override def children: Seq[Expression] = completeFunction :: inputData :: Nil
+ override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
- override def dataType: DataType = ArrayType(completeFunction.dataType)
+ override def dataType: DataType = ArrayType(lambdaFunction.dataType)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val javaType = ctx.javaType(dataType)
- val elementJavaType = ctx.javaType(elementType)
+ val elementJavaType = ctx.javaType(loopVar.dataType)
val genInputData = inputData.gen(ctx)
-
- // Variables to hold the element that is currently being processed.
- val loopValue = ctx.freshName("loopValue")
- val loopIsNull = ctx.freshName("loopIsNull")
-
- val loopVariable = LambdaVariable(loopValue, loopIsNull, elementType)
- val substitutedFunction = completeFunction transform {
- case a: AttributeReference if a == loopAttribute => loopVariable
- }
- // A hack to run this through the analyzer (to bind extractions).
- val boundFunction =
- SimpleAnalyzer.execute(Project(Alias(substitutedFunction, "")() :: Nil, LocalRelation(Nil)))
- .expressions.head.children.head
-
- val genFunction = boundFunction.gen(ctx)
+ val genFunction = lambdaFunction.gen(ctx)
val dataLength = ctx.freshName("dataLength")
val convertedArray = ctx.freshName("convertedArray")
val loopIndex = ctx.freshName("loopIndex")
- val convertedType = ctx.boxedType(boundFunction.dataType)
+ val convertedType = ctx.boxedType(lambdaFunction.dataType)
// Because of the way Java defines nested arrays, we have to handle the syntax specially.
// Specifically, we have to insert the [$dataLength] in between the type and any extra nested
@@ -446,9 +437,9 @@ case class MapObjects(
}
val loopNullCheck = if (primitiveElement) {
- s"boolean $loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
+ s"boolean ${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
} else {
- s"boolean $loopIsNull = ${genInputData.isNull} || $loopValue == null;"
+ s"boolean ${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
}
s"""
@@ -464,11 +455,11 @@ case class MapObjects(
int $loopIndex = 0;
while ($loopIndex < $dataLength) {
- $elementJavaType $loopValue =
+ $elementJavaType ${loopVar.value} =
($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
$loopNullCheck
- if ($loopIsNull) {
+ if (${loopVar.isNull}) {
$convertedArray[$loopIndex] = null;
} else {
${genFunction.code}
http://git-wip-us.apache.org/repos/asf/spark/blob/d8ec081c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index d6ca138..7233e0f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -145,6 +145,7 @@ class ExpressionEncoderSuite extends SparkFunSuite {
case class InnerClass(i: Int)
productTest(InnerClass(1))
+ encodeDecodeTest(Array(InnerClass(1)), "array of inner class")
productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org