You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2016/09/06 02:36:09 UTC

spark git commit: [SPARK-17279][SQL] better error message for exceptions during ScalaUDF execution

Repository: spark
Updated Branches:
  refs/heads/master 6d86403d8 -> 8d08f43d0


[SPARK-17279][SQL] better error message for exceptions during ScalaUDF execution

## What changes were proposed in this pull request?

If `ScalaUDF` throws an exception while executing user code, it is sometimes hard for users to figure out what went wrong, especially when they use the Spark shell. An example:
```
org.apache.spark.SparkException: Job aborted due to stage failure: Task 12 in stage 325.0 failed 4 times, most recent failure: Lost task 12.3 in stage 325.0 (TID 35622, 10.0.207.202): java.lang.NullPointerException
	at line8414e872fb8b42aba390efc153d1611a12.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$2.apply(<console>:40)
	at line8414e872fb8b42aba390efc153d1611a12.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$2.apply(<console>:40)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
...
```
We should catch these exceptions and rethrow them with a better error message that states the exception happened in a Scala UDF.

This PR also does some cleanup for `ScalaUDF` and adds a unit test suite for it.

## How was this patch tested?

The new test suite (`ScalaUDFSuite`).

Author: Wenchen Fan <we...@databricks.com>

Closes #14850 from cloud-fan/npe.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8d08f43d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8d08f43d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8d08f43d

Branch: refs/heads/master
Commit: 8d08f43d09157b98e559c0be6ce6fd571a35e0d1
Parents: 6d86403
Author: Wenchen Fan <we...@databricks.com>
Authored: Tue Sep 6 10:36:00 2016 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Sep 6 10:36:00 2016 +0800

----------------------------------------------------------------------
 .../spark/ml/recommendation/ALSSuite.scala      | 16 +++----
 .../sql/catalyst/expressions/ScalaUDF.scala     | 44 ++++++++++++------
 .../catalyst/expressions/ScalaUDFSuite.scala    | 48 ++++++++++++++++++++
 3 files changed, 86 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8d08f43d/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index e8ed50a..d0aa2cd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -510,18 +510,18 @@ class ALSSuite
       (1, 1L, 1d, 0, 0L, 0d, 5.0)
     ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating")
     withClue("fit should fail when ids exceed integer range. ") {
-      assert(intercept[IllegalArgumentException] {
+      assert(intercept[SparkException] {
         als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
-      }.getMessage.contains("was out of Integer range"))
-      assert(intercept[IllegalArgumentException] {
+      }.getCause.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
         als.fit(df.select(df("user_small").as("user"), df("item"), df("rating")))
-      }.getMessage.contains("was out of Integer range"))
-      assert(intercept[IllegalArgumentException] {
+      }.getCause.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
         als.fit(df.select(df("item_big").as("item"), df("user"), df("rating")))
-      }.getMessage.contains("was out of Integer range"))
-      assert(intercept[IllegalArgumentException] {
+      }.getCause.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
         als.fit(df.select(df("item_small").as("item"), df("user"), df("rating")))
-      }.getMessage.contains("was out of Integer range"))
+      }.getCause.getMessage.contains("was out of Integer range"))
     }
     withClue("transform should fail when ids exceed integer range. ") {
       val model = als.fit(df)

http://git-wip-us.apache.org/repos/asf/spark/blob/8d08f43d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 2139064..6cfdea9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types.DataType
@@ -994,20 +995,15 @@ case class ScalaUDF(
       ctx: CodegenContext,
       ev: ExprCode): ExprCode = {
 
-    ctx.references += this
-
-    val scalaUDFClassName = classOf[ScalaUDF].getName
+    val scalaUDF = ctx.addReferenceObj("scalaUDF", this)
     val converterClassName = classOf[Any => Any].getName
     val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
-    val expressionClassName = classOf[Expression].getName
 
     // Generate codes used to convert the returned value of user-defined functions to Catalyst type
     val catalystConverterTerm = ctx.freshName("catalystConverter")
-    val catalystConverterTermIdx = ctx.references.size - 1
     ctx.addMutableState(converterClassName, catalystConverterTerm,
       s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" +
-        s".createToCatalystConverter((($scalaUDFClassName)references" +
-          s"[$catalystConverterTermIdx]).dataType());")
+        s".createToCatalystConverter($scalaUDF.dataType());")
 
     val resultTerm = ctx.freshName("result")
 
@@ -1019,10 +1015,8 @@ case class ScalaUDF(
     val funcClassName = s"scala.Function${children.size}"
 
     val funcTerm = ctx.freshName("udf")
-    val funcExpressionIdx = ctx.references.size - 1
     ctx.addMutableState(funcClassName, funcTerm,
-      s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)references" +
-        s"[$funcExpressionIdx]).userDefinedFunc());")
+      s"this.$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();")
 
     // codegen for children expressions
     val evals = children.map(_.genCode(ctx))
@@ -1039,9 +1033,16 @@ case class ScalaUDF(
       (convert, argTerm)
     }.unzip
 
-    val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " +
-      s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" +
-        s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));"
+    val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})"
+    val callFunc =
+      s"""
+         ${ctx.boxedType(dataType)} $resultTerm = null;
+         try {
+           $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult);
+         } catch (Exception e) {
+           throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e);
+         }
+       """
 
     ev.copy(code = s"""
       $evalCode
@@ -1057,5 +1058,20 @@ case class ScalaUDF(
 
   private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType)
 
-  override def eval(input: InternalRow): Any = converter(f(input))
+  lazy val udfErrorMessage = {
+    val funcCls = function.getClass.getSimpleName
+    val inputTypes = children.map(_.dataType.simpleString).mkString(", ")
+    s"Failed to execute user defined function($funcCls: ($inputTypes) => ${dataType.simpleString})"
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val result = try {
+      f(input)
+    } catch {
+      case e: Exception =>
+        throw new SparkException(udfErrorMessage, e)
+    }
+
+    converter(result)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8d08f43d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
new file mode 100644
index 0000000..7e45028
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.types.{IntegerType, StringType}
+
+class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  test("basic") {
+    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
+    checkEvaluation(intUdf, 2)
+
+    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
+    checkEvaluation(stringUdf, "ax")
+  }
+
+  test("better error message for NPE") {
+    val udf = ScalaUDF(
+      (s: String) => s.toLowerCase,
+      StringType,
+      Literal.create(null, StringType) :: Nil)
+
+    val e1 = intercept[SparkException](udf.eval())
+    assert(e1.getMessage.contains("Failed to execute user defined function"))
+
+    val e2 = intercept[SparkException] {
+      checkEvalutionWithUnsafeProjection(udf, null)
+    }
+    assert(e2.getMessage.contains("Failed to execute user defined function"))
+  }
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org