You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2016/01/05 21:33:34 UTC

spark git commit: [SPARK-12439][SQL] Fix toCatalystArray and MapObjects

Repository: spark
Updated Branches:
  refs/heads/master 8ce645d4e -> d202ad2fc


[SPARK-12439][SQL] Fix toCatalystArray and MapObjects

JIRA: https://issues.apache.org/jira/browse/SPARK-12439

In toCatalystArray, we should look at the data type returned by dataTypeFor instead of silentSchemaFor to determine whether the element is a native type. An obvious problem arises when the element is of type Option[Int]: silentSchemaFor will return Int, and we will then wrongly recognize the element as a native type.

There is another problem when using Option as an array element. When we encode data like Seq(Some(1), Some(2), None) with an encoder, we later use MapObjects to construct an array for it. But in MapObjects, we don't check whether the return value of lambdaFunction is null. As a result, the decoded data for Seq(Some(1), Some(2), None) would be Seq(1, 2, -1) instead of Seq(1, 2, null).

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #10391 from viirya/fix-catalystarray.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d202ad2f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d202ad2f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d202ad2f

Branch: refs/heads/master
Commit: d202ad2fc24b54de38ad7bfb646bf7703069e9f7
Parents: 8ce645d
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Tue Jan 5 12:33:21 2016 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Jan 5 12:33:21 2016 -0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/catalyst/ScalaReflection.scala  |  2 +-
 .../apache/spark/sql/catalyst/encoders/RowEncoder.scala  | 11 ++++++++---
 .../apache/spark/sql/catalyst/expressions/objects.scala  |  4 ++--
 .../sql/catalyst/encoders/ExpressionEncoderSuite.scala   |  3 +++
 4 files changed, 14 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d202ad2f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c6aa60b..b0efdf3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -405,7 +405,7 @@ object ScalaReflection extends ScalaReflection {
     def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
       val externalDataType = dataTypeFor(elementType)
       val Schema(catalystType, nullable) = silentSchemaFor(elementType)
-      if (isNativeType(catalystType)) {
+      if (isNativeType(externalDataType)) {
         NewInstance(
           classOf[GenericArrayData],
           input :: Nil,

http://git-wip-us.apache.org/repos/asf/spark/blob/d202ad2f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 6f3d5ba..3903086 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -35,7 +35,8 @@ object RowEncoder {
   def apply(schema: StructType): ExpressionEncoder[Row] = {
     val cls = classOf[Row]
     val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
-    val extractExpressions = extractorsFor(inputObject, schema)
+    // We use an If expression to wrap extractorsFor result of StructType
+    val extractExpressions = extractorsFor(inputObject, schema).asInstanceOf[If].falseValue
     val constructExpression = constructorFor(schema)
     new ExpressionEncoder[Row](
       schema,
@@ -129,7 +130,9 @@ object RowEncoder {
             Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil),
             f.dataType))
       }
-      CreateStruct(convertedFields)
+      If(IsNull(inputObject),
+        Literal.create(null, inputType),
+        CreateStruct(convertedFields))
   }
 
   private def externalDataTypeFor(dt: DataType): DataType = dt match {
@@ -220,6 +223,8 @@ object RowEncoder {
           Literal.create(null, externalDataTypeFor(f.dataType)),
           constructorFor(GetStructField(input, i)))
       }
-      CreateExternalRow(convertedFields)
+      If(IsNull(input),
+        Literal.create(null, externalDataTypeFor(input.dataType)),
+        CreateExternalRow(convertedFields))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d202ad2f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index fb404c1..c0c3e6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -456,10 +456,10 @@ case class MapObjects(
             ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
           $loopNullCheck
 
-          if (${loopVar.isNull}) {
+          ${genFunction.code}
+          if (${genFunction.isNull}) {
             $convertedArray[$loopIndex] = null;
           } else {
-            ${genFunction.code}
             $convertedArray[$loopIndex] = ${genFunction.value};
           }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d202ad2f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 6453f1c..98f29e5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -160,6 +160,9 @@ class ExpressionEncoderSuite extends SparkFunSuite {
 
   productTest(OptionalData(None, None, None, None, None, None, None, None))
 
+  encodeDecodeTest(Seq(Some(1), None), "Option in array")
+  encodeDecodeTest(Map(1 -> Some(10L), 2 -> Some(20L), 3 -> None), "Option in map")
+
   productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
 
   productTest(BoxedData(null, null, null, null, null, null, null))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org