Posted to commits@spark.apache.org by ge...@apache.org on 2023/02/16 00:26:36 UTC

[spark] branch master updated: [SPARK-42384][SQL] Check for null input in generated code for mask function

This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7ff8ba29f2c [SPARK-42384][SQL] Check for null input in generated code for mask function
7ff8ba29f2c is described below

commit 7ff8ba29f2cb03686fb79c44e82223485e480ea4
Author: Bruce Robbins <be...@gmail.com>
AuthorDate: Wed Feb 15 16:26:23 2023 -0800

    [SPARK-42384][SQL] Check for null input in generated code for mask function
    
    ### What changes were proposed in this pull request?
    
    When generating code for the `mask` function, call `ctx.nullSafeExec` to produce null-safe code.
    
    This change assumes that the mask function returns null only when the input is null (which appears to be the case, from reading the code of `Mask.transformInput`).
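    
    For context, `CodegenContext.nullSafeExec` only wraps the supplied evaluation code in a null check when the child is nullable; otherwise it returns the code unchanged. Below is a paraphrased sketch of that helper (the real method lives in `CodeGenerator.scala` and may differ in minor details):
    
    ```
    // Paraphrased sketch of CodegenContext.nullSafeExec: it returns generated Java
    // source as a String, guarding `execute` behind the child's isNull flag when
    // the child is nullable.
    def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = {
      if (nullable) {
        s"""
           |if (!$isNull) {
           |  $execute
           |}
         """.stripMargin
      } else {
        // Non-nullable child: emit the evaluation code as-is, no guard needed.
        "\n" + execute
      }
    }
    ```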
    
    ### Why are the changes needed?
    
    The following query fails with a `NullPointerException`:
    ```
    create or replace temp view v1 as
    select * from values
    (null),
    ('AbCD123-$#')
    as data(col1);
    
    cache table v1;
    
    select mask(col1) from v1;
    
    23/02/07 16:36:06 ERROR Executor: Exception in task 0.0 in stage 3.0 (TID 3)
    java.lang.NullPointerException
            at org.apache.spark.sql.catalyst.expressions.codegen.UnsafeWriter.write(UnsafeWriter.java:110)
            at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
            at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
            at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
    ```
    The generated code calls `UnsafeWriter.write(0, value_0)` regardless of whether `Mask.transformInput` returns null or not. The `UnsafeWriter.write` method for `UTF8String` does not expect a null pointer.
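    
    The same failure can be reproduced through the DataFrame API. A sketch, assuming a `spark` session as in `spark-shell` and a build that does not yet include this change; caching the data forces an `InMemoryRelation` scan, so the projection of `mask(col1)` goes through whole-stage codegen and `UnsafeWriter`:
    
    ```
    // Sketch only: on a build without this fix the query throws a NullPointerException
    // for the null row; with the fix it returns a null result instead.
    import spark.implicits._

    val v1 = Seq(Option("AbCD123-$#"), None).toDF("col1")
    v1.cache()

    v1.selectExpr("mask(col1)").show()
    ```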
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New unit tests.
    
    Closes #39945 from bersprockets/mask_npe_issue.
    
    Authored-by: Bruce Robbins <be...@gmail.com>
    Signed-off-by: Gengliang Wang <ge...@apache.org>
---
 .../sql/catalyst/expressions/maskExpressions.scala | 22 +++++++++++++++++++---
 .../expressions/StringExpressionsSuite.scala       |  6 ++++++
 .../apache/spark/sql/StringFunctionsSuite.scala    | 12 ++++++++++++
 3 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala
index e2828f35232..af74e7c0f7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala
@@ -223,8 +223,23 @@ case class Mask(
     val fifthGen = children(4).genCode(ctx)
     val resultCode =
       f(firstGen.value, secondGen.value, thirdGen.value, fourthGen.value, fifthGen.value)
-    ev.copy(
-      code = code"""
+    if (nullable) {
+      // this function is somewhat like a `UnaryExpression`, in that only the first child
+      // determines whether the result is null
+      val nullSafeEval = ctx.nullSafeExec(children(0).nullable, firstGen.isNull)(resultCode)
+      ev.copy(code = code"""
+        ${firstGen.code}
+        ${secondGen.code}
+        ${thirdGen.code}
+        ${fourthGen.code}
+        ${fifthGen.code}
+        boolean ${ev.isNull} = ${firstGen.isNull};
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        $nullSafeEval
+      """)
+    } else {
+      ev.copy(
+        code = code"""
         ${firstGen.code}
         ${secondGen.code}
         ${thirdGen.code}
@@ -232,7 +247,8 @@ case class Mask(
         ${fifthGen.code}
         ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
         $resultCode""",
-      isNull = FalseLiteral)
+        isNull = FalseLiteral)
+    }
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 11ec7f6babc..017f4483e88 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -440,6 +440,12 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
+  test("SPARK-42384: Mask with null input") {
+    val NULL_LITERAL = Literal(null, StringType)
+    checkEvaluation(
+      new Mask(NULL_LITERAL, Literal('Q'), Literal('q'), Literal('d')), null)
+  }
+
   test("string for ascii") {
     val a = $"a".long.at(0)
     checkEvaluation(Chr(Literal(48L)), "0", create_row("abdef"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 3c4ad7c1dca..23e71bb2f49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -702,4 +702,16 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
       )
     )
   }
+
+  test("SPARK-42384: mask with null input") {
+    val df = Seq(
+      ("AbCD123-@$#"),
+      (null)
+    ).toDF("a")
+
+    checkAnswer(
+      df.selectExpr("mask(a,'Q','q','d','o')"),
+      Row("QqQQdddoooo") :: Row(null) :: Nil
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org