You are viewing a plain text version of this content; the canonical (HTML) link was present in the original page but is not available in this plain-text export.
Posted to commits@spark.apache.org by da...@apache.org on 2015/12/11 20:15:56 UTC

spark git commit: [SPARK-12258] [SQL] passing null into ScalaUDF (follow-up)

Repository: spark
Updated Branches:
  refs/heads/master 518ab5101 -> c119a34d1


[SPARK-12258] [SQL] passing null into ScalaUDF (follow-up)

This is a follow-up PR for #10259

Author: Davies Liu <da...@databricks.com>

Closes #10266 from davies/null_udf2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c119a34d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c119a34d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c119a34d

Branch: refs/heads/master
Commit: c119a34d1e9e599e302acfda92e5de681086a19f
Parents: 518ab51
Author: Davies Liu <da...@databricks.com>
Authored: Fri Dec 11 11:15:53 2015 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Fri Dec 11 11:15:53 2015 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/ScalaUDF.scala     | 31 +++++++++++---------
 .../org/apache/spark/sql/DataFrameSuite.scala   |  8 +++--
 2 files changed, 23 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c119a34d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 5deb2f8..85faa19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -1029,24 +1029,27 @@ case class ScalaUDF(
     // such as IntegerType, its javaType is `int` and the returned type of user-defined
     // function is Object. Trying to convert an Object to `int` will cause casting exception.
     val evalCode = evals.map(_.code).mkString
-    val funcArguments = converterTerms.zipWithIndex.map {
-      case (converter, i) =>
-        val eval = evals(i)
-        val dt = children(i).dataType
-        s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)}) ${eval.value})"
-    }.mkString(",")
-    val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " +
-      s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" +
-        s".apply($funcTerm.apply($funcArguments));"
+    val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) =>
+      val eval = evals(i)
+      val argTerm = ctx.freshName("arg")
+      val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});"
+      (convert, argTerm)
+    }.unzip
 
-    evalCode + s"""
-      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-      Boolean ${ev.isNull};
+    val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " +
+      s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" +
+        s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));"
 
+    s"""
+      $evalCode
+      ${converters.mkString("\n")}
       $callFunc
 
-      ${ev.value} = $resultTerm;
-      ${ev.isNull} = $resultTerm == null;
+      boolean ${ev.isNull} = $resultTerm == null;
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${ev.value} = $resultTerm;
+      }
     """
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c119a34d/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 8887dc6..5353fef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1144,9 +1144,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
 
     // passing null into the UDF that could handle it
     val boxedUDF = udf[java.lang.Integer, java.lang.Integer] {
-      (i: java.lang.Integer) => if (i == null) -10 else i * 2
+      (i: java.lang.Integer) => if (i == null) -10 else null
     }
-    checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil)
+    checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil)
+
+    sqlContext.udf.register("boxedUDF",
+      (i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer)
+    checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil)
 
     val primitiveUDF = udf((i: Int) => i * 2)
     checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org