You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2020/07/11 13:47:08 UTC

[spark] branch branch-2.4 updated: [SPARK-32238][SQL] Use Utils.getSimpleName to avoid hitting Malformed class name in ScalaUDF

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 4c82ae8  [SPARK-32238][SQL] Use Utils.getSimpleName to avoid hitting Malformed class name in ScalaUDF
4c82ae8 is described below

commit 4c82ae8a1e9e6f63d142931d9edbce430eac6e39
Author: yi.wu <yi...@databricks.com>
AuthorDate: Sat Jul 11 06:27:56 2020 -0700

    [SPARK-32238][SQL] Use Utils.getSimpleName to avoid hitting Malformed class name in ScalaUDF
    
    This PR proposes to use `Utils.getSimpleName(function.getClass)` instead of `function.getClass.getSimpleName` to get the class name.
    
    For some functions (see the demo below), using `function.getClass.getSimpleName` can hit the "Malformed class name" error.
    
    Yes, the error shown to the user changes. For the demo,
    
    ```scala
     object MalformedClassObject extends Serializable {
        class MalformedNonPrimitiveFunction extends (String => Int) with Serializable {
          override def apply(v1: String): Int = v1.toInt / 0
        }
      }
     OuterScopes.addOuterScope(MalformedClassObject)
     val f = new MalformedClassObject.MalformedNonPrimitiveFunction()
     Seq("20").toDF("col").select(udf(f).apply(Column("col"))).collect()
    ```
    
    Before this PR, the user can only see the "Malformed class name" error:
    
    ```scala
    An exception or error caused a run to abort: Malformed class name
    java.lang.InternalError: Malformed class name
    	at java.lang.Class.getSimpleName(Class.java:1330)
    	at org.apache.spark.sql.catalyst.expressions.ScalaUDF.udfErrorMessage$lzycompute(ScalaUDF.scala:1157)
    	at org.apache.spark.sql.catalyst.expressions.ScalaUDF.udfErrorMessage(ScalaUDF.scala:1155)
    	at org.apache.spark.sql.catalyst.expressions.ScalaUDF.doGenCode(ScalaUDF.scala:1077)
    	at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:147)
    	at scala.Option.getOrElse(Option.scala:189)
    	at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:142)
    	at org.apache.spark.sql.catalyst.expressions.Alias.genCode(namedExpressions.scala:160)
    	at org.apache.spark.sql.execution.ProjectExec.$anonfun$doConsume$1(basicPhysicalOperators.scala:69)
            ...
    ```
    
    After this PR, the user can see the real root cause of the UDF failure:
    
    ```scala
    org.apache.spark.SparkException: Failed to execute user defined function(UDFSuite$MalformedClassObject$MalformedNonPrimitiveFunction: (string) => int)
    	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
    	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:753)
    	at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:340)
    	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:898)
    	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:898)
    	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
    	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
    	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    	at org.apache.spark.scheduler.Task.run(Task.scala:127)
    	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:464)
    	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
    	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:467)
    	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    	at java.lang.Thread.run(Thread.java:748)
    Caused by: java.lang.ArithmeticException: / by zero
    	at org.apache.spark.sql.UDFSuite$MalformedClassObject$MalformedNonPrimitiveFunction.apply(UDFSuite.scala:677)
    	at org.apache.spark.sql.UDFSuite$MalformedClassObject$MalformedNonPrimitiveFunction.apply(UDFSuite.scala:676)
    	... 17 more
    
    ```
    
    Added a test.
    
    Closes #29050 from Ngone51/fix-malformed-udf.
    
    Authored-by: yi.wu <yi...@databricks.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
    (cherry picked from commit 0c9196e5493628d343ef67bb9e83d0c95ff3943a)
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../spark/sql/catalyst/expressions/ScalaUDF.scala  |  5 ++--
 .../test/scala/org/apache/spark/sql/UDFSuite.scala | 30 ++++++++++++++++++++++
 2 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index fae90ca..fb69a6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -21,7 +21,8 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType, UserDefinedType}
+import org.apache.spark.util.Utils
 
 /**
  * User-defined function.
@@ -1052,7 +1053,7 @@ case class ScalaUDF(
   private[this] val resultConverter = CatalystTypeConverters.createToCatalystConverter(dataType)
 
   lazy val udfErrorMessage = {
-    val funcCls = function.getClass.getSimpleName
+    val funcCls = Utils.getSimpleName(function.getClass)
     val inputTypes = children.map(_.dataType.catalogString).mkString(", ")
     val outputType = dataType.catalogString
     s"Failed to execute user defined function($funcCls: ($inputTypes) => $outputType)"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index f8ed21b..595a6a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,7 +17,11 @@
 
 package org.apache.spark.sql
 
+import java.math.BigDecimal
+
+import org.apache.spark.SparkException
 import org.apache.spark.sql.api.java._
+import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
@@ -417,4 +421,30 @@ class UDFSuite extends QueryTest with SharedSQLContext {
       checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null")))
     }
   }
+
+  object MalformedClassObject extends Serializable {
+    class MalformedNonPrimitiveFunction extends (String => Int) with Serializable {
+      override def apply(v1: String): Int = v1.toInt / 0
+    }
+
+    class MalformedPrimitiveFunction extends (Int => Int) with Serializable {
+      override def apply(v1: Int): Int = v1 / 0
+    }
+  }
+
+  test("SPARK-32238: Use Utils.getSimpleName to avoid hitting Malformed class name") {
+    OuterScopes.addOuterScope(MalformedClassObject)
+    val f1 = new MalformedClassObject.MalformedNonPrimitiveFunction()
+    val f2 = new MalformedClassObject.MalformedPrimitiveFunction()
+
+    val e1 = intercept[SparkException] {
+      Seq("20").toDF("col").select(udf(f1).apply(Column("col"))).collect()
+    }
+    assert(e1.getMessage.contains("UDFSuite$MalformedClassObject$MalformedNonPrimitiveFunction"))
+
+    val e2 = intercept[SparkException] {
+      Seq(20).toDF("col").select(udf(f2).apply(Column("col"))).collect()
+    }
+    assert(e2.getMessage.contains("UDFSuite$MalformedClassObject$MalformedPrimitiveFunction"))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org