You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2020/07/11 13:35:40 UTC

[spark] branch branch-3.0 updated: [SPARK-32238][SQL] Use Utils.getSimpleName to avoid hitting Malformed class name in ScalaUDF

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new a0e5025  [SPARK-32238][SQL] Use Utils.getSimpleName to avoid hitting Malformed class name in ScalaUDF
a0e5025 is described below

commit a0e5025fb8c88e53414a24bbe983c3848d71d173
Author: yi.wu <yi...@databricks.com>
AuthorDate: Sat Jul 11 06:27:56 2020 -0700

    [SPARK-32238][SQL] Use Utils.getSimpleName to avoid hitting Malformed class name in ScalaUDF
    
    This PR proposes to use `Utils.getSimpleName(function.getClass)` instead of `function.getClass.getSimpleName` to get the class name.
    
    For some functions (see the demo below), using `function.getClass.getSimpleName` can hit the "Malformed class name" error.
    
    Yes, this changes the error message that users see. For the demo,
    
    ```scala
     object MalformedClassObject extends Serializable {
        class MalformedNonPrimitiveFunction extends (String => Int) with Serializable {
          override def apply(v1: String): Int = v1.toInt / 0
        }
      }
     OuterScopes.addOuterScope(MalformedClassObject)
     val f = new MalformedClassObject.MalformedNonPrimitiveFunction()
     Seq("20").toDF("col").select(udf(f).apply(Column("col"))).collect()
    ```
    
    Before this PR, the user can only see the "Malformed class name" error:
    
    ```scala
    An exception or error caused a run to abort: Malformed class name
    java.lang.InternalError: Malformed class name
    	at java.lang.Class.getSimpleName(Class.java:1330)
    	at org.apache.spark.sql.catalyst.expressions.ScalaUDF.udfErrorMessage$lzycompute(ScalaUDF.scala:1157)
    	at org.apache.spark.sql.catalyst.expressions.ScalaUDF.udfErrorMessage(ScalaUDF.scala:1155)
    	at org.apache.spark.sql.catalyst.expressions.ScalaUDF.doGenCode(ScalaUDF.scala:1077)
    	at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:147)
    	at scala.Option.getOrElse(Option.scala:189)
    	at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:142)
    	at org.apache.spark.sql.catalyst.expressions.Alias.genCode(namedExpressions.scala:160)
    	at org.apache.spark.sql.execution.ProjectExec.$anonfun$doConsume$1(basicPhysicalOperators.scala:69)
            ...
    ```
    
    After this PR, the user can see the real root cause of the UDF failure:
    
    ```scala
    org.apache.spark.SparkException: Failed to execute user defined function(UDFSuite$MalformedClassObject$MalformedNonPrimitiveFunction: (string) => int)
    	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
    	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:753)
    	at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:340)
    	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:898)
    	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:898)
    	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
    	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
    	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    	at org.apache.spark.scheduler.Task.run(Task.scala:127)
    	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:464)
    	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
    	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:467)
    	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    	at java.lang.Thread.run(Thread.java:748)
    Caused by: java.lang.ArithmeticException: / by zero
    	at org.apache.spark.sql.UDFSuite$MalformedClassObject$MalformedNonPrimitiveFunction.apply(UDFSuite.scala:677)
    	at org.apache.spark.sql.UDFSuite$MalformedClassObject$MalformedNonPrimitiveFunction.apply(UDFSuite.scala:676)
    	... 17 more
    
    ```
    
    Added a test.
    
    Closes #29050 from Ngone51/fix-malformed-udf.
    
    Authored-by: yi.wu <yi...@databricks.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
    (cherry picked from commit 0c9196e5493628d343ef67bb9e83d0c95ff3943a)
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../spark/sql/catalyst/expressions/ScalaUDF.scala  |  5 ++--
 .../test/scala/org/apache/spark/sql/UDFSuite.scala | 28 ++++++++++++++++++++++
 2 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index e80f03e..58a9f68 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -24,7 +24,8 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
+import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType, UserDefinedType}
+import org.apache.spark.util.Utils
 
 /**
  * User-defined function.
@@ -1115,7 +1116,7 @@ case class ScalaUDF(
   private[this] val resultConverter = CatalystTypeConverters.createToCatalystConverter(dataType)
 
   lazy val udfErrorMessage = {
-    val funcCls = function.getClass.getSimpleName
+    val funcCls = Utils.getSimpleName(function.getClass)
     val inputTypes = children.map(_.dataType.catalogString).mkString(", ")
     val outputType = dataType.catalogString
     s"Failed to execute user defined function($funcCls: ($inputTypes) => $outputType)"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index e2747d7..91e9f1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql
 
 import java.math.BigDecimal
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.api.java._
+import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.execution.{QueryExecution, SimpleMode}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
@@ -581,4 +583,30 @@ class UDFSuite extends QueryTest with SharedSparkSession {
       .toDF("col1", "col2")
     checkAnswer(df.select(myUdf(Column("col1"), Column("col2"))), Row(2020) :: Nil)
   }
+
+  object MalformedClassObject extends Serializable {
+    class MalformedNonPrimitiveFunction extends (String => Int) with Serializable {
+      override def apply(v1: String): Int = v1.toInt / 0
+    }
+
+    class MalformedPrimitiveFunction extends (Int => Int) with Serializable {
+      override def apply(v1: Int): Int = v1 / 0
+    }
+  }
+
+  test("SPARK-32238: Use Utils.getSimpleName to avoid hitting Malformed class name") {
+    OuterScopes.addOuterScope(MalformedClassObject)
+    val f1 = new MalformedClassObject.MalformedNonPrimitiveFunction()
+    val f2 = new MalformedClassObject.MalformedPrimitiveFunction()
+
+    val e1 = intercept[SparkException] {
+      Seq("20").toDF("col").select(udf(f1).apply(Column("col"))).collect()
+    }
+    assert(e1.getMessage.contains("UDFSuite$MalformedClassObject$MalformedNonPrimitiveFunction"))
+
+    val e2 = intercept[SparkException] {
+      Seq(20).toDF("col").select(udf(f2).apply(Column("col"))).collect()
+    }
+    assert(e2.getMessage.contains("UDFSuite$MalformedClassObject$MalformedPrimitiveFunction"))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org