You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ge...@apache.org on 2023/02/16 00:26:36 UTC
[spark] branch master updated: [SPARK-42384][SQL] Check for null input in generated code for mask function
This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7ff8ba29f2c [SPARK-42384][SQL] Check for null input in generated code for mask function
7ff8ba29f2c is described below
commit 7ff8ba29f2cb03686fb79c44e82223485e480ea4
Author: Bruce Robbins <be...@gmail.com>
AuthorDate: Wed Feb 15 16:26:23 2023 -0800
[SPARK-42384][SQL] Check for null input in generated code for mask function
### What changes were proposed in this pull request?
When generating code for the `mask` function, call `ctx.nullSafeExec` to produce null-safe code.
This change assumes that the `mask` function returns null only when its input (the first argument) is null, which appears to be the case from reading the code of `Mask.transformInput`.
### Why are the changes needed?
The following query fails with a `NullPointerException`:
```
create or replace temp view v1 as
select * from values
(null),
('AbCD123-$#')
as data(col1);
cache table v1;
select mask(col1) from v1;
23/02/07 16:36:06 ERROR Executor: Exception in task 0.0 in stage 3.0 (TID 3)
java.lang.NullPointerException
at org.apache.spark.sql.catalyst.expressions.codegen.UnsafeWriter.write(UnsafeWriter.java:110)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
```
The generated code calls `UnsafeWriter.write(0, value_0)` regardless of whether `Mask.transformInput` returned null. The `UnsafeWriter.write` overload for `UTF8String` does not accept a null reference, so a null input triggers the `NullPointerException` shown above.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New unit tests.
Closes #39945 from bersprockets/mask_npe_issue.
Authored-by: Bruce Robbins <be...@gmail.com>
Signed-off-by: Gengliang Wang <ge...@apache.org>
---
.../sql/catalyst/expressions/maskExpressions.scala | 22 +++++++++++++++++++---
.../expressions/StringExpressionsSuite.scala | 6 ++++++
.../apache/spark/sql/StringFunctionsSuite.scala | 12 ++++++++++++
3 files changed, 37 insertions(+), 3 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala
index e2828f35232..af74e7c0f7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala
@@ -223,8 +223,23 @@ case class Mask(
val fifthGen = children(4).genCode(ctx)
val resultCode =
f(firstGen.value, secondGen.value, thirdGen.value, fourthGen.value, fifthGen.value)
- ev.copy(
- code = code"""
+ if (nullable) {
+ // this function is somewhat like a `UnaryExpression`, in that only the first child
+ // determines whether the result is null
+ val nullSafeEval = ctx.nullSafeExec(children(0).nullable, firstGen.isNull)(resultCode)
+ ev.copy(code = code"""
+ ${firstGen.code}
+ ${secondGen.code}
+ ${thirdGen.code}
+ ${fourthGen.code}
+ ${fifthGen.code}
+ boolean ${ev.isNull} = ${firstGen.isNull};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ $nullSafeEval
+ """)
+ } else {
+ ev.copy(
+ code = code"""
${firstGen.code}
${secondGen.code}
${thirdGen.code}
@@ -232,7 +247,8 @@ case class Mask(
${fifthGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""",
- isNull = FalseLiteral)
+ isNull = FalseLiteral)
+ }
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 11ec7f6babc..017f4483e88 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -440,6 +440,12 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
+ test("SPARK-42384: Mask with null input") {
+ val NULL_LITERAL = Literal(null, StringType)
+ checkEvaluation(
+ new Mask(NULL_LITERAL, Literal('Q'), Literal('q'), Literal('d')), null)
+ }
+
test("string for ascii") {
val a = $"a".long.at(0)
checkEvaluation(Chr(Literal(48L)), "0", create_row("abdef"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 3c4ad7c1dca..23e71bb2f49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -702,4 +702,16 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
)
)
}
+
+ test("SPARK-42384: mask with null input") {
+ val df = Seq(
+ ("AbCD123-@$#"),
+ (null)
+ ).toDF("a")
+
+ checkAnswer(
+ df.selectExpr("mask(a,'Q','q','d','o')"),
+ Row("QqQQdddoooo") :: Row(null) :: Nil
+ )
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org