Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2019/02/04 08:00:12 UTC

gatorsmile commented on a change in pull request #23712: [SPARK-26798][SQL] HandleNullInputsForUDF should trust nullability
URL: https://github.com/apache/spark/pull/23712#discussion_r253370793
 
 

 ##########
 File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 ##########
 @@ -2147,27 +2147,34 @@ class Analyzer(
 
       case p => p transformExpressionsUp {
 
-        case udf @ ScalaUDF(_, _, inputs, inputsNullSafe, _, _, _, _)
-            if inputsNullSafe.contains(false) =>
+        case udf @ ScalaUDF(_, _, inputs, primitiveInputs, _, _, _, _)
+            if primitiveInputs.contains(true) =>
           // Otherwise, add special handling of null for fields that can't accept null.
           // The result of operations like this, when passed null, is generally to return null.
-          assert(inputsNullSafe.length == inputs.length)
-
-          // TODO: skip null handling for not-nullable primitive inputs after we can completely
-          // trust the `nullable` information.
-          val inputsNullCheck = inputsNullSafe.zip(inputs)
-            .filter { case (nullSafe, _) => !nullSafe }
-            .map { case (_, expr) => IsNull(expr) }
-            .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
-          // Once we add an `If` check above the udf, it is safe to mark those checked inputs
-          // as null-safe (i.e., set `inputsNullSafe` all `true`), because the null-returning
-          // branch of `If` will be called if any of these checked inputs is null. Thus we can
-          // prevent this rule from being applied repeatedly.
-          val newInputsNullSafe = inputsNullSafe.map(_ => true)
-          inputsNullCheck
-            .map(If(_, Literal.create(null, udf.dataType),
-              udf.copy(inputsNullSafe = newInputsNullSafe)))
-            .getOrElse(udf)
+          assert(primitiveInputs.length == inputs.length)
+
+          val inputNullCheck = primitiveInputs.zip(inputs).collect {
+            case (isPrimitive, input) if isPrimitive && input.nullable => IsNull(input)
 
 Review comment:
   This is the major change made in this PR. 
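
  Since the effect of the rule is hard to see from the truncated diff alone, below is a
  minimal, self-contained Scala sketch of the wrapping logic it implements. The classes
  used here (Col, IsNull, Or, If, NullLit, Udf) are simplified stand-ins, not Catalyst's
  actual expression types, and handleNullInputs only loosely mirrors the patched rule:
  for each primitive-typed parameter whose input is still nullable, emit an IsNull check,
  OR the checks together, and guard the UDF with an If that returns null instead of
  invoking it.

    // Hypothetical, simplified expression model -- not Spark's Catalyst classes.
    object NullGuardSketch extends App {

      sealed trait Expr { def nullable: Boolean }
      case class Col(name: String, nullable: Boolean) extends Expr
      case class IsNull(child: Expr) extends Expr { val nullable = false }
      case class Or(left: Expr, right: Expr) extends Expr { val nullable = false }
      case class If(pred: Expr, onTrue: Expr, onFalse: Expr) extends Expr {
        val nullable: Boolean = onTrue.nullable || onFalse.nullable
      }
      case object NullLit extends Expr { val nullable = true }
      case class Udf(inputs: Seq[Expr], primitiveInputs: Seq[Boolean]) extends Expr {
        val nullable = true
      }

      // Core of the rule as sketched here: only primitive parameters need a guard,
      // and only when the corresponding input is actually nullable ("trust nullability").
      def handleNullInputs(udf: Udf): Expr = {
        val checks = udf.primitiveInputs.zip(udf.inputs).collect {
          case (isPrimitive, input) if isPrimitive && input.nullable => IsNull(input)
        }
        checks.reduceLeftOption[Expr](Or(_, _))   // OR all the IsNull checks together
          .map(If(_, NullLit, udf))               // short-circuit to null if any check fires
          .getOrElse(udf)                         // no nullable primitive inputs: leave UDF alone
      }

      // Example: only the nullable primitive input "a" gets a guard; "b" is
      // primitive but declared non-nullable, so it is skipped.
      val udf = Udf(
        inputs = Seq(Col("a", nullable = true), Col("b", nullable = false)),
        primitiveInputs = Seq(true, true))
      println(handleNullInputs(udf))
      // If(IsNull(Col(a,true)),NullLit,Udf(...))
    }

  Running the example wraps the UDF in a check on column "a" only, since "b" is
  non-nullable. Per the TODO removed in the diff, the pre-patch code added the check for
  every input the parameter type could not accept null for, without consulting the
  input's nullability.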

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
