You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/07/26 07:06:36 UTC

spark git commit: [SPARK-24878][SQL] Fix reverse function for array type of primitive type containing null.

Repository: spark
Updated Branches:
  refs/heads/master d2e7deb59 -> c9b233d41


[SPARK-24878][SQL] Fix reverse function for array type of primitive type containing null.

## What changes were proposed in this pull request?

If we use the `reverse` function on an array of a primitive type that contains `null`, and the child array is an `UnsafeArrayData`, the function returns a wrong result. This is because `UnsafeArrayData` doesn't define the behavior of re-assignment; in particular, we can't set a valid value after we set `null`.

## How was this patch tested?

Added some tests.

Author: Takuya UESHIN <ue...@databricks.com>

Closes #21830 from ueshin/issues/SPARK-24878/fix_reverse.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c9b233d4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c9b233d4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c9b233d4

Branch: refs/heads/master
Commit: c9b233d4144790c3e57e1a1d1602ad5dc354e8a8
Parents: d2e7deb
Author: Takuya UESHIN <ue...@databricks.com>
Authored: Thu Jul 26 15:06:13 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Jul 26 15:06:13 2018 +0800

----------------------------------------------------------------------
 .../expressions/collectionOperations.scala      | 66 +++++++++++---------
 .../spark/sql/DataFrameFunctionsSuite.scala     | 63 +++++++++++++------
 2 files changed, 80 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c9b233d4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index f438748..b3d04bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1244,46 +1244,50 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
   }
 
   private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
-    val length = ctx.freshName("length")
-    val javaElementType = CodeGenerator.javaType(elementType)
+
     val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
 
+    val numElements = ctx.freshName("numElements")
+    val arrayData = ctx.freshName("arrayData")
+
     val initialization = if (isPrimitiveType) {
-      s"$childName.copy()"
+      ctx.createUnsafeArray(arrayData, numElements, elementType, s" $prettyName failed.")
     } else {
-      s"new ${classOf[GenericArrayData].getName()}(new Object[$length])"
-    }
-
-    val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length
-
-    val swapAssigments = if (isPrimitiveType) {
-      val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType)
-      val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index)
-      s"""|boolean isNullAtK = ${ev.value}.isNullAt(k);
-          |boolean isNullAtL = ${ev.value}.isNullAt(l);
-          |if(!isNullAtK) {
-          |  $javaElementType el = ${getCall("k")};
-          |  if(!isNullAtL) {
-          |    ${ev.value}.$setFunc(k, ${getCall("l")});
-          |  } else {
-          |    ${ev.value}.setNullAt(k);
-          |  }
-          |  ${ev.value}.$setFunc(l, el);
-          |} else if (!isNullAtL) {
-          |  ${ev.value}.$setFunc(k, ${getCall("l")});
-          |  ${ev.value}.setNullAt(l);
-          |}""".stripMargin
+      val arrayDataClass = classOf[GenericArrayData].getName
+      s"$arrayDataClass $arrayData = new $arrayDataClass(new Object[$numElements]);"
+    }
+
+    val i = ctx.freshName("i")
+    val j = ctx.freshName("j")
+
+    val getValue = CodeGenerator.getValue(childName, elementType, i)
+
+    val setFunc = if (isPrimitiveType) {
+      s"set${CodeGenerator.primitiveTypeName(elementType)}"
+    } else {
+      "update"
+    }
+
+    val assignment = if (isPrimitiveType && dataType.asInstanceOf[ArrayType].containsNull) {
+      s"""
+         |if ($childName.isNullAt($i)) {
+         |  $arrayData.setNullAt($j);
+         |} else {
+         |  $arrayData.$setFunc($j, $getValue);
+         |}
+       """.stripMargin
     } else {
-      s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});"
+      s"$arrayData.$setFunc($j, $getValue);"
     }
 
     s"""
-       |final int $length = $childName.numElements();
-       |${ev.value} = $initialization;
-       |for(int k = 0; k < $numberOfIterations; k++) {
-       |  int l = $length - k - 1;
-       |  $swapAssigments
+       |final int $numElements = $childName.numElements();
+       |$initialization
+       |for (int $i = 0; $i < $numElements; $i++) {
+       |  int $j = $numElements - $i - 1;
+       |  $assignment
        |}
+       |${ev.value} = $arrayData;
      """.stripMargin
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c9b233d4/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index bf04251..5a7bd45 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -901,8 +901,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("reverse function") {
-    // String test cases
+  test("reverse function - string") {
     val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i")
     def testString(): Unit = {
       checkAnswer(oneRowDF.select(reverse('s)), Seq(Row("krapS")))
@@ -917,37 +916,61 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     // Test with cached relation, the Project will be evaluated with codegen
     oneRowDF.cache()
     testString()
+  }
 
-    // Array test cases (primitive-type elements)
-    val idf = Seq(
+  test("reverse function - array for primitive type not containing null") {
+    val idfNotContainsNull = Seq(
       Seq(1, 9, 8, 7),
       Seq(5, 8, 9, 7, 2),
       Seq.empty,
       null
     ).toDF("i")
 
-    def testArray(): Unit = {
+    def testArrayOfPrimitiveTypeNotContainsNull(): Unit = {
       checkAnswer(
-        idf.select(reverse('i)),
+        idfNotContainsNull.select(reverse('i)),
         Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
       )
       checkAnswer(
-        idf.selectExpr("reverse(i)"),
+        idfNotContainsNull.selectExpr("reverse(i)"),
         Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
       )
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testArrayOfPrimitiveTypeNotContainsNull()
+    // Test with cached relation, the Project will be evaluated with codegen
+    idfNotContainsNull.cache()
+    testArrayOfPrimitiveTypeNotContainsNull()
+  }
+
+  test("reverse function - array for primitive type containing null") {
+    val idfContainsNull = Seq[Seq[Integer]](
+      Seq(1, 9, 8, null, 7),
+      Seq(null, 5, 8, 9, 7, 2),
+      Seq.empty,
+      null
+    ).toDF("i")
+
+    def testArrayOfPrimitiveTypeContainsNull(): Unit = {
+      checkAnswer(
+        idfContainsNull.select(reverse('i)),
+        Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null))
+      )
       checkAnswer(
-        idf.selectExpr("reverse(array(1, null, 2, null))"),
-        Seq.fill(idf.count().toInt)(Row(Seq(null, 2, null, 1)))
+        idfContainsNull.selectExpr("reverse(i)"),
+        Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null))
       )
     }
 
     // Test with local relation, the Project will be evaluated without codegen
-    testArray()
+    testArrayOfPrimitiveTypeContainsNull()
     // Test with cached relation, the Project will be evaluated with codegen
-    idf.cache()
-    testArray()
+    idfContainsNull.cache()
+    testArrayOfPrimitiveTypeContainsNull()
+  }
 
-    // Array test cases (non-primitive-type elements)
+  test("reverse function - array for non-primitive type") {
     val sdf = Seq(
       Seq("c", "a", "b"),
       Seq("b", null, "c", null),
@@ -975,14 +998,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     // Test with cached relation, the Project will be evaluated with codegen
     sdf.cache()
     testArrayOfNonPrimitiveType()
+  }
 
-    // Error test cases
-    intercept[AnalysisException] {
-      oneRowDF.selectExpr("reverse(struct(1, 'a'))")
+  test("reverse function - data type mismatch") {
+    val ex1 = intercept[AnalysisException] {
+      sql("select reverse(struct(1, 'a'))")
     }
-    intercept[AnalysisException] {
-      oneRowDF.selectExpr("reverse(map(1, 'a'))")
+    assert(ex1.getMessage.contains("data type mismatch"))
+
+    val ex2 = intercept[AnalysisException] {
+      sql("select reverse(map(1, 'a'))")
     }
+    assert(ex2.getMessage.contains("data type mismatch"))
   }
 
   test("array position function") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org