Posted to commits@spark.apache.org by ya...@apache.org on 2019/02/13 05:07:41 UTC

[spark] branch master updated: [SPARK-26798][SQL] HandleNullInputsForUDF should trust nullability

This is an automated email from the ASF dual-hosted git repository.

yamamuro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f502e20  [SPARK-26798][SQL] HandleNullInputsForUDF should trust nullability
f502e20 is described below

commit f502e209f49d3d76f947d1a8ba38c8c8a86e0bef
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Wed Feb 13 14:07:03 2019 +0900

    [SPARK-26798][SQL] HandleNullInputsForUDF should trust nullability
    
    ## What changes were proposed in this pull request?
    
    There is a very old TODO in `HandleNullInputsForUDF` saying that we can skip the null check if the input is not nullable. We already rely on nullability info in many places, so we can trust it here too.
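    
    For example (an illustrative sketch, not part of the original message):
    given a UDF with one nullable primitive input `a` and one non-nullable
    primitive input `b`, the analyzer now only guards `a`:
    
        // before this patch: if (isnull(a) OR isnull(b)) null else udf(a, b)
        // after this patch:  if (isnull(a)) null else udf(knownnotnull(a), b)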
    
    ## How was this patch tested?
    
    Re-enabled a previously ignored test in `AnalysisSuite`.
    
    Closes #23712 from cloud-fan/minor.
    
    Lead-authored-by: Wenchen Fan <we...@databricks.com>
    Co-authored-by: Xiao Li <ga...@gmail.com>
    Signed-off-by: Takeshi Yamamuro <ya...@apache.org>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 47 ++++++++++++---------
 .../spark/sql/catalyst/expressions/ScalaUDF.scala  |  9 ++--
 .../sql/catalyst/analysis/AnalysisSuite.scala      | 36 +++++++++-------
 .../sql/catalyst/expressions/ScalaUDFSuite.scala   |  6 +--
 .../spark/sql/catalyst/trees/TreeNodeSuite.scala   |  2 +-
 .../org/apache/spark/sql/UDFRegistration.scala     | 48 +++++++++++-----------
 .../datasources/FileFormatDataWriter.scala         |  2 +-
 .../sql/expressions/UserDefinedFunction.scala      |  8 ++--
 .../test/scala/org/apache/spark/sql/UDFSuite.scala | 37 ++++++++++++-----
 9 files changed, 114 insertions(+), 81 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a84bb76..793c337 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2147,27 +2147,36 @@ class Analyzer(
 
       case p => p transformExpressionsUp {
 
-        case udf @ ScalaUDF(_, _, inputs, inputsNullSafe, _, _, _, _)
-            if inputsNullSafe.contains(false) =>
+        case udf @ ScalaUDF(_, _, inputs, inputPrimitives, _, _, _, _)
+            if inputPrimitives.contains(true) =>
           // Otherwise, add special handling of null for fields that can't accept null.
           // The result of operations like this, when passed null, is generally to return null.
-          assert(inputsNullSafe.length == inputs.length)
-
-          // TODO: skip null handling for not-nullable primitive inputs after we can completely
-          // trust the `nullable` information.
-          val inputsNullCheck = inputsNullSafe.zip(inputs)
-            .filter { case (nullSafe, _) => !nullSafe }
-            .map { case (_, expr) => IsNull(expr) }
-            .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
-          // Once we add an `If` check above the udf, it is safe to mark those checked inputs
-          // as null-safe (i.e., set `inputsNullSafe` all `true`), because the null-returning
-          // branch of `If` will be called if any of these checked inputs is null. Thus we can
-          // prevent this rule from being applied repeatedly.
-          val newInputsNullSafe = inputsNullSafe.map(_ => true)
-          inputsNullCheck
-            .map(If(_, Literal.create(null, udf.dataType),
-              udf.copy(inputsNullSafe = newInputsNullSafe)))
-            .getOrElse(udf)
+          assert(inputPrimitives.length == inputs.length)
+
+          val inputPrimitivesPair = inputPrimitives.zip(inputs)
+          val inputNullCheck = inputPrimitivesPair.collect {
+            case (isPrimitive, input) if isPrimitive && input.nullable =>
+              IsNull(input)
+          }.reduceLeftOption[Expression](Or)
+
+          if (inputNullCheck.isDefined) {
+            // Once we add an `If` check above the udf, it is safe to mark those checked inputs
+            // as null-safe (i.e., wrap with `KnownNotNull`), because the null-returning
+            // branch of `If` will be called if any of these checked inputs is null. Thus we can
+            // prevent this rule from being applied repeatedly.
+            val newInputs = inputPrimitivesPair.map {
+              case (isPrimitive, input) =>
+                if (isPrimitive && input.nullable) {
+                  KnownNotNull(input)
+                } else {
+                  input
+                }
+            }
+            val newUDF = udf.copy(children = newInputs)
+            If(inputNullCheck.get, Literal.create(null, udf.dataType), newUDF)
+          } else {
+            udf
+          }
       }
     }
   }
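
The shape of the new rule can be sketched in isolation (a simplified,
self-contained Scala model of the logic above, not Spark code):

    // Only inputs that are both primitive and nullable contribute to the guard.
    case class In(name: String, primitive: Boolean, nullable: Boolean)
    val inputs = Seq(In("s", false, true), In("d", true, true), In("x", true, false))
    val checks = inputs.collect { case i if i.primitive && i.nullable => s"isnull(${i.name})" }
    val guard  = checks.reduceLeftOption(_ + " OR " + _)
    // guard == Some("isnull(d)"): "x" is primitive but non-nullable, so it is trusted
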
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index c9e0a2e..eb45f08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -31,9 +31,10 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType}
  *                  null. Use boxed type or [[Option]] if you wanna do the null-handling yourself.
  * @param dataType  Return type of function.
  * @param children  The input expressions of this UDF.
- * @param inputsNullSafe Whether the inputs are of non-primitive types or not nullable. Null values
- *                       of Scala primitive types will be converted to the type's default value and
- *                       lead to wrong results, thus need special handling before calling the UDF.
+ * @param inputPrimitives The analyzer should be aware of Scala primitive types so as to make the
+ *                        UDF return null if there is any null input value of these types. On the
+ *                        other hand, Java UDFs can only have boxed types, thus this parameter will
+ *                        always be all false.
  * @param inputTypes  The expected input types of this UDF, used to perform type coercion. If we do
  *                    not want to perform coercion, simply use "Nil". Note that it would've been
  *                    better to use Option of Seq[DataType] so we can use "None" as the case for no
@@ -47,7 +48,7 @@ case class ScalaUDF(
     function: AnyRef,
     dataType: DataType,
     children: Seq[Expression],
-    inputsNullSafe: Seq[Boolean],
+    inputPrimitives: Seq[Boolean],
     inputTypes: Seq[AbstractDataType] = Nil,
     udfName: Option[String] = None,
     nullable: Boolean = true,
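
With the renamed field, constructing a ScalaUDF by hand looks like the
following (a hedged sketch; the attribute `i` is assumed for the example,
not taken from the patch):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ScalaUDF}
    import org.apache.spark.sql.types.IntegerType

    val i = AttributeReference("i", IntegerType)()
    val plusOne = ScalaUDF(
      (n: Int) => n + 1,
      IntegerType,
      children = i :: Nil,
      inputPrimitives = true :: Nil)  // Int is a Scala primitive
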
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 9829484..8038733 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -303,51 +303,57 @@ class AnalysisSuite extends AnalysisTest with Matchers {
   }
 
   test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
-    val string = testRelation2.output(0)
-    val double = testRelation2.output(2)
-    val short = testRelation2.output(4)
+    val testRelation = LocalRelation(
+      AttributeReference("a", StringType)(),
+      AttributeReference("b", DoubleType)(),
+      AttributeReference("c", ShortType)(),
+      AttributeReference("d", DoubleType, nullable = false)())
+
+    val string = testRelation.output(0)
+    val double = testRelation.output(1)
+    val short = testRelation.output(2)
+    val nonNullableDouble = testRelation.output(3)
     val nullResult = Literal.create(null, StringType)
 
     def checkUDF(udf: Expression, transformed: Expression): Unit = {
       checkAnalysis(
-        Project(Alias(udf, "")() :: Nil, testRelation2),
-        Project(Alias(transformed, "")() :: Nil, testRelation2)
+        Project(Alias(udf, "")() :: Nil, testRelation),
+        Project(Alias(transformed, "")() :: Nil, testRelation)
       )
     }
 
     // non-primitive parameters do not need special null handling
-    val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil, true :: Nil)
+    val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil, false :: Nil)
     val expected1 = udf1
     checkUDF(udf1, expected1)
 
     // only primitive parameter needs special null handling
     val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil,
-      true :: false :: Nil)
+      false :: true :: Nil)
     val expected2 =
-      If(IsNull(double), nullResult, udf2.copy(inputsNullSafe = true :: true :: Nil))
+      If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil))
     checkUDF(udf2, expected2)
 
     // special null handling should apply to all primitive parameters
     val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil,
-      false :: false :: Nil)
+      true :: true :: Nil)
     val expected3 = If(
       IsNull(short) || IsNull(double),
       nullResult,
-      udf3.copy(inputsNullSafe = true :: true :: Nil))
+      udf3.copy(children = KnownNotNull(short) :: KnownNotNull(double) :: Nil))
     checkUDF(udf3, expected3)
 
     // we can skip special null handling for primitive parameters that are not nullable
-    // TODO: this is disabled for now as we can not completely trust `nullable`.
     val udf4 = ScalaUDF(
       (s: Short, d: Double) => "x",
       StringType,
-      short :: double.withNullability(false) :: Nil,
-      false :: false :: Nil)
+      short :: nonNullableDouble :: Nil,
+      true :: true :: Nil)
     val expected4 = If(
       IsNull(short),
       nullResult,
-      udf4.copy(inputsNullSafe = true :: true :: Nil))
-    // checkUDF(udf4, expected4)
+      udf4.copy(children = KnownNotNull(short) :: nonNullableDouble :: Nil))
+    checkUDF(udf4, expected4)
   }
 
   test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
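
The re-enabled `udf4` case above is the heart of the change: with a
non-nullable attribute the analyzer now emits a guard only for the nullable
primitive input. Conceptually (pretty-printed by hand, not suite output):

    // input:    udf4(short, nonNullableDouble)   -- both declared primitive
    // analyzed: if (isnull(short)) null
    //           else udf4(knownnotnull(short), nonNullableDouble)
    // nonNullableDouble needs neither an IsNull check nor a KnownNotNull wrapper
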
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
index 467cfd5..df92fa3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -29,7 +29,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
     val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil)
     checkEvaluation(intUdf, 2)
 
-    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil)
+    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil)
     checkEvaluation(stringUdf, "ax")
   }
 
@@ -38,7 +38,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
       (s: String) => s.toLowerCase(Locale.ROOT),
       StringType,
       Literal.create(null, StringType) :: Nil,
-      true :: Nil)
+      false :: Nil)
 
     val e1 = intercept[SparkException](udf.eval())
     assert(e1.getMessage.contains("Failed to execute user defined function"))
@@ -51,7 +51,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("SPARK-22695: ScalaUDF should not use global variables") {
     val ctx = new CodegenContext
-    ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil).genCode(ctx)
+    ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil).genCode(ctx)
     assert(ctx.inlinedMutableStates.isEmpty)
   }
 }
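
Note that the fourth argument flips meaning in this file: it used to answer
"is this input null-safe?" and now answers "is this input a Scala primitive?",
which is why the flags invert for the String cases while `intUdf` keeps
`true :: Nil` (a side-by-side restatement of the lines above, not new tests):

    ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil)        // Int: primitive
    ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil) // String: boxed
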
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 64aa1ee..cb911d7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -564,7 +564,7 @@ class TreeNodeSuite extends SparkFunSuite {
   }
 
   test("toJSON should not throws java.lang.StackOverflowError") {
-    val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr), true :: Nil)
+    val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr), false :: Nil)
     // Should not throw java.lang.StackOverflowError
     udf.toJSON
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index fe5d1af..83425de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -151,7 +151,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
         |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = {
         |  val func = f$anyCast.call($anyParams)
         |  def builder(e: Seq[Expression]) = if (e.length == $i) {
-        |    ScalaUDF($funcCall, returnType, e, e.map(_ => true), udfName = Some(name))
+        |    ScalaUDF($funcCall, returnType, e, e.map(_ => false), udfName = Some(name))
         |  } else {
         |    throw new AnalysisException("Invalid number of arguments for function " + name +
         |      ". Expected: $i; Found: " + e.length)
@@ -719,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF0[_], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF0[Any]].call()
     def builder(e: Seq[Expression]) = if (e.length == 0) {
-      ScalaUDF(() => func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(() => func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 0; Found: " + e.length)
@@ -734,7 +734,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
     def builder(e: Seq[Expression]) = if (e.length == 1) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 1; Found: " + e.length)
@@ -749,7 +749,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 2) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 2; Found: " + e.length)
@@ -764,7 +764,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 3) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 3; Found: " + e.length)
@@ -779,7 +779,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 4) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 4; Found: " + e.length)
@@ -794,7 +794,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 5) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 5; Found: " + e.length)
@@ -809,7 +809,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 6) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 6; Found: " + e.length)
@@ -824,7 +824,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 7) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 7; Found: " + e.length)
@@ -839,7 +839,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 8) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 8; Found: " + e.length)
@@ -854,7 +854,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 9) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 9; Found: " + e.length)
@@ -869,7 +869,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 10) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 10; Found: " + e.length)
@@ -884,7 +884,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 11) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 11; Found: " + e.length)
@@ -899,7 +899,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 12) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 12; Found: " + e.length)
@@ -914,7 +914,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 13) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 13; Found: " + e.length)
@@ -929,7 +929,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 14) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 14; Found: " + e.length)
@@ -944,7 +944,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 15) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 15; Found: " + e.length)
@@ -959,7 +959,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 16) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 16; Found: " + e.length)
@@ -974,7 +974,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 17) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 17; Found: " + e.length)
@@ -989,7 +989,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 18) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 18; Found: " + e.length)
@@ -1004,7 +1004,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 19) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 19; Found: " + e.length)
@@ -1019,7 +1019,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 20) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 20; Found: " + e.length)
@@ -1034,7 +1034,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 21) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 21; Found: " + e.length)
@@ -1049,7 +1049,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
     val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 22) {
-      ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 22; Found: " + e.length)
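
Every generated Java-facing overload above passes `e.map(_ => false)` because
the Java UDF interfaces only accept boxed argument types, so no input ever
needs the primitive null guard. A hedged usage sketch (assumes a live
SparkSession named `spark`):

    import org.apache.spark.sql.api.java.UDF2
    import org.apache.spark.sql.types.IntegerType

    spark.udf.register("boxedAdd", new UDF2[Integer, Integer, Integer] {
      override def call(a: Integer, b: Integer): Integer =
        if (a == null || b == null) null else Integer.valueOf(a + b)
    }, IntegerType)
    // SELECT boxedAdd(1, NULL) lets the function observe the null itself
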
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
index b6b0d7a..2595cc6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
@@ -181,7 +181,7 @@ class DynamicPartitionDataWriter(
         ExternalCatalogUtils.getPartitionPathString _,
         StringType,
         Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId))),
-        Seq(true, true))
+        Seq(false, false))
       if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName)
     })
 
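Both inputs to `getPartitionPathString` here are StringType expressions (a
literal column name and a value cast to string), i.e. non-primitive, hence
`Seq(false, false)`. For intuition (assumed values, not from the patch):

    // ExternalCatalogUtils.getPartitionPathString("dt", "2019-02-13")
    //   => "dt=2019-02-13", with special characters escaped for the path
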
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 4d8e1c5..0c956ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -100,14 +100,16 @@ private[sql] case class SparkUserDefinedFunction(
 
   private[sql] def createScalaUDF(exprs: Seq[Expression]): ScalaUDF = {
     // It's possible that some of the inputs don't have a specific type(e.g. `Any`),  skip type
-    // check and null check for them.
+    // check.
     val inputTypes = inputSchemas.map(_.map(_.dataType).getOrElse(AnyDataType))
-    val inputsNullSafe = inputSchemas.map(_.map(_.nullable).getOrElse(true))
+    // `ScalaReflection.Schema.nullable` is false iff the type is primitive. Also `Any` is not
+    // primitive.
+    val inputsPrimitive = inputSchemas.map(_.map(!_.nullable).getOrElse(false))
     ScalaUDF(
       f,
       dataType,
       exprs,
-      inputsNullSafe,
+      inputsPrimitive,
       inputTypes,
       udfName = name,
       nullable = nullable,
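
The comment added above is the load-bearing fact: `ScalaReflection` marks a
schema non-nullable exactly when the Scala type is primitive. A REPL-style
sketch of that inference (an assumption check, not part of the patch):

    import org.apache.spark.sql.catalyst.ScalaReflection

    ScalaReflection.schemaFor[Int].nullable     // false -> primitive, gets the guard
    ScalaReflection.schemaFor[String].nullable  // true  -> reference type, no guard
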
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 5ac2093..e515800 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -400,17 +400,23 @@ class UDFSuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-25044 Verify null input handling for primitive types - with udf()") {
-    val udf1 = udf((x: Long, y: Any) => x * 2 + (if (y == null) 1 else 0))
-    val df = spark.range(0, 3).toDF("a")
-      .withColumn("b", udf1($"a", lit(null)))
-      .withColumn("c", udf1(lit(null), $"a"))
-
-    checkAnswer(
-      df,
-      Seq(
-        Row(0, 1, null),
-        Row(1, 3, null),
-        Row(2, 5, null)))
+    val input = Seq(
+      (null, Integer.valueOf(1), "x"),
+      ("M", null, "y"),
+      ("N", Integer.valueOf(3), null)).toDF("a", "b", "c")
+
+    val udf1 = udf((a: String, b: Int, c: Any) => a + b + c)
+    val df = input.select(udf1('a, 'b, 'c))
+    checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null")))
+
+    // test Java UDF. Java UDF can't have primitive inputs, as it's generic typed.
+    val udf2 = udf(new UDF3[String, Integer, Object, String] {
+      override def call(t1: String, t2: Integer, t3: Object): String = {
+        t1 + t2 + t3
+      }
+    }, StringType)
+    val df2 = input.select(udf2('a, 'b, 'c))
+    checkAnswer(df2, Seq(Row("null1x"), Row("Mnully"), Row("N3null")))
   }
 
   test("SPARK-25044 Verify null input handling for primitive types - with udf.register") {
@@ -420,6 +426,15 @@ class UDFSuite extends QueryTest with SharedSQLContext {
       spark.udf.register("f", (a: String, b: Int, c: Any) => a + b + c)
       val df = spark.sql("SELECT f(a, b, c) FROM t")
       checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null")))
+
+      // test Java UDF. Java UDF can't have primitive inputs, as it's generic typed.
+      spark.udf.register("f2", new UDF3[String, Integer, Object, String] {
+        override def call(t1: String, t2: Integer, t3: Object): String = {
+          t1 + t2 + t3
+        }
+      }, StringType)
+      val df2 = spark.sql("SELECT f2(a, b, c) FROM t")
+      checkAnswer(df2, Seq(Row("null1x"), Row("Mnully"), Row("N3null")))
     }
   }
 

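End to end, the user-visible contract exercised by these tests is: a Scala UDF
with a primitive parameter returns null for a null input, while boxed (Java)
parameters let the function body see the null. A hedged spark-shell sketch
(assumes `spark` and its implicits, mirroring the suite's data):

    import org.apache.spark.sql.functions.udf
    import spark.implicits._   // assumes a live SparkSession named `spark`

    val inc = udf((i: Int) => i + 1)
    Seq(Tuple1(Integer.valueOf(1)), Tuple1(null: Integer)).toDF("i")
      .select(inc($"i"))
      .collect()
    // => Array(Row(2), Row(null)) -- the null never reaches (i: Int) => i + 1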
