You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/07/26 07:06:36 UTC
spark git commit: [SPARK-24878][SQL] Fix reverse function for array
type of primitive type containing null.
Repository: spark
Updated Branches:
refs/heads/master d2e7deb59 -> c9b233d41
[SPARK-24878][SQL] Fix reverse function for array type of primitive type containing null.
## What changes were proposed in this pull request?
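If the `reverse` function is applied to an array of a primitive type that contains `null`, and the child array is an `UnsafeArrayData`, the function returns an incorrect result. This happens because `UnsafeArrayData` does not define the behavior of re-assigning an element; in particular, a valid value cannot be stored in a slot after that slot has been set to `null`. The previous generated code copied the child array and reversed it by swapping elements in place, which re-assigns slots and can leave stale null bits behind. The fix instead allocates a new result array (an unsafe array for primitive element types, a `GenericArrayData` otherwise) and writes each element exactly once, calling `setNullAt` only for elements that are null in the input.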
## How was this patch tested?
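Added tests to `DataFrameFunctionsSuite`. The single `reverse` test is split into separate tests for strings, primitive-type arrays without nulls, primitive-type arrays containing nulls (the case fixed here), non-primitive-type arrays, and data type mismatch errors. Each case is run both without codegen (local relation) and with codegen (cached relation).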
Author: Takuya UESHIN <ue...@databricks.com>
Closes #21830 from ueshin/issues/SPARK-24878/fix_reverse.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c9b233d4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c9b233d4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c9b233d4
Branch: refs/heads/master
Commit: c9b233d4144790c3e57e1a1d1602ad5dc354e8a8
Parents: d2e7deb
Author: Takuya UESHIN <ue...@databricks.com>
Authored: Thu Jul 26 15:06:13 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Jul 26 15:06:13 2018 +0800
----------------------------------------------------------------------
.../expressions/collectionOperations.scala | 66 +++++++++++---------
.../spark/sql/DataFrameFunctionsSuite.scala | 63 +++++++++++++------
2 files changed, 80 insertions(+), 49 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/c9b233d4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index f438748..b3d04bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1244,46 +1244,50 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
}
private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
- val length = ctx.freshName("length")
- val javaElementType = CodeGenerator.javaType(elementType)
+
val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
+ val numElements = ctx.freshName("numElements")
+ val arrayData = ctx.freshName("arrayData")
+
val initialization = if (isPrimitiveType) {
- s"$childName.copy()"
+ ctx.createUnsafeArray(arrayData, numElements, elementType, s" $prettyName failed.")
} else {
- s"new ${classOf[GenericArrayData].getName()}(new Object[$length])"
- }
-
- val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length
-
- val swapAssigments = if (isPrimitiveType) {
- val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType)
- val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index)
- s"""|boolean isNullAtK = ${ev.value}.isNullAt(k);
- |boolean isNullAtL = ${ev.value}.isNullAt(l);
- |if(!isNullAtK) {
- | $javaElementType el = ${getCall("k")};
- | if(!isNullAtL) {
- | ${ev.value}.$setFunc(k, ${getCall("l")});
- | } else {
- | ${ev.value}.setNullAt(k);
- | }
- | ${ev.value}.$setFunc(l, el);
- |} else if (!isNullAtL) {
- | ${ev.value}.$setFunc(k, ${getCall("l")});
- | ${ev.value}.setNullAt(l);
- |}""".stripMargin
+ val arrayDataClass = classOf[GenericArrayData].getName
+ s"$arrayDataClass $arrayData = new $arrayDataClass(new Object[$numElements]);"
+ }
+
+ val i = ctx.freshName("i")
+ val j = ctx.freshName("j")
+
+ val getValue = CodeGenerator.getValue(childName, elementType, i)
+
+ val setFunc = if (isPrimitiveType) {
+ s"set${CodeGenerator.primitiveTypeName(elementType)}"
+ } else {
+ "update"
+ }
+
+ val assignment = if (isPrimitiveType && dataType.asInstanceOf[ArrayType].containsNull) {
+ s"""
+ |if ($childName.isNullAt($i)) {
+ | $arrayData.setNullAt($j);
+ |} else {
+ | $arrayData.$setFunc($j, $getValue);
+ |}
+ """.stripMargin
} else {
- s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});"
+ s"$arrayData.$setFunc($j, $getValue);"
}
s"""
- |final int $length = $childName.numElements();
- |${ev.value} = $initialization;
- |for(int k = 0; k < $numberOfIterations; k++) {
- | int l = $length - k - 1;
- | $swapAssigments
+ |final int $numElements = $childName.numElements();
+ |$initialization
+ |for (int $i = 0; $i < $numElements; $i++) {
+ | int $j = $numElements - $i - 1;
+ | $assignment
|}
+ |${ev.value} = $arrayData;
""".stripMargin
}
http://git-wip-us.apache.org/repos/asf/spark/blob/c9b233d4/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index bf04251..5a7bd45 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -901,8 +901,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
}
}
- test("reverse function") {
- // String test cases
+ test("reverse function - string") {
val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i")
def testString(): Unit = {
checkAnswer(oneRowDF.select(reverse('s)), Seq(Row("krapS")))
@@ -917,37 +916,61 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
// Test with cached relation, the Project will be evaluated with codegen
oneRowDF.cache()
testString()
+ }
- // Array test cases (primitive-type elements)
- val idf = Seq(
+ test("reverse function - array for primitive type not containing null") {
+ val idfNotContainsNull = Seq(
Seq(1, 9, 8, 7),
Seq(5, 8, 9, 7, 2),
Seq.empty,
null
).toDF("i")
- def testArray(): Unit = {
+ def testArrayOfPrimitiveTypeNotContainsNull(): Unit = {
checkAnswer(
- idf.select(reverse('i)),
+ idfNotContainsNull.select(reverse('i)),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
checkAnswer(
- idf.selectExpr("reverse(i)"),
+ idfNotContainsNull.selectExpr("reverse(i)"),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
+ }
+
+ // Test with local relation, the Project will be evaluated without codegen
+ testArrayOfPrimitiveTypeNotContainsNull()
+ // Test with cached relation, the Project will be evaluated with codegen
+ idfNotContainsNull.cache()
+ testArrayOfPrimitiveTypeNotContainsNull()
+ }
+
+ test("reverse function - array for primitive type containing null") {
+ val idfContainsNull = Seq[Seq[Integer]](
+ Seq(1, 9, 8, null, 7),
+ Seq(null, 5, 8, 9, 7, 2),
+ Seq.empty,
+ null
+ ).toDF("i")
+
+ def testArrayOfPrimitiveTypeContainsNull(): Unit = {
+ checkAnswer(
+ idfContainsNull.select(reverse('i)),
+ Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null))
+ )
checkAnswer(
- idf.selectExpr("reverse(array(1, null, 2, null))"),
- Seq.fill(idf.count().toInt)(Row(Seq(null, 2, null, 1)))
+ idfContainsNull.selectExpr("reverse(i)"),
+ Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null))
)
}
// Test with local relation, the Project will be evaluated without codegen
- testArray()
+ testArrayOfPrimitiveTypeContainsNull()
// Test with cached relation, the Project will be evaluated with codegen
- idf.cache()
- testArray()
+ idfContainsNull.cache()
+ testArrayOfPrimitiveTypeContainsNull()
+ }
- // Array test cases (non-primitive-type elements)
+ test("reverse function - array for non-primitive type") {
val sdf = Seq(
Seq("c", "a", "b"),
Seq("b", null, "c", null),
@@ -975,14 +998,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
// Test with cached relation, the Project will be evaluated with codegen
sdf.cache()
testArrayOfNonPrimitiveType()
+ }
- // Error test cases
- intercept[AnalysisException] {
- oneRowDF.selectExpr("reverse(struct(1, 'a'))")
+ test("reverse function - data type mismatch") {
+ val ex1 = intercept[AnalysisException] {
+ sql("select reverse(struct(1, 'a'))")
}
- intercept[AnalysisException] {
- oneRowDF.selectExpr("reverse(map(1, 'a'))")
+ assert(ex1.getMessage.contains("data type mismatch"))
+
+ val ex2 = intercept[AnalysisException] {
+ sql("select reverse(map(1, 'a'))")
}
+ assert(ex2.getMessage.contains("data type mismatch"))
}
test("array position function") {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org