You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2020/02/21 07:19:25 UTC

[spark] branch branch-3.0 updated: [SPARK-26580][SQL][ML][FOLLOW-UP] Throw exception when use untyped UDF by default

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new bc30a07  [SPARK-26580][SQL][ML][FOLLOW-UP] Throw exception when use untyped UDF by default
bc30a07 is described below

commit bc30a07ce262840c99a752db4fbd3a423f652017
Author: yi.wu <yi...@databricks.com>
AuthorDate: Fri Feb 21 14:46:54 2020 +0800

    [SPARK-26580][SQL][ML][FOLLOW-UP] Throw exception when use untyped UDF by default
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to throw an exception by default when a user uses the untyped UDF (a.k.a. `org.apache.spark.sql.functions.udf(AnyRef, DataType)`).
    
    Users can still use it by setting `spark.sql.legacy.allowUntypedScalaUDF` to `true`.
    
    ### Why are the changes needed?
    
    According to #23498, since Spark 3.0, the untyped UDF will return the default value of the Java type if the input value is null. For example, with `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` will return 0 in Spark 3.0 but null in Spark 2.4. This behavior change was introduced because Spark 3.0 is built with Scala 2.12 by default.
    
    As a result, this might change data silently and may cause correctness issues if users still expect `null` in some cases. Thus, we should encourage users to use typed UDFs to avoid this problem.
    
    ### Does this PR introduce any user-facing change?
    
    Yes. Users will now hit an exception when using the untyped UDF.
    
    ### How was this patch tested?
    
    Added a new test and updated some existing tests.
    
    Closes #27488 from Ngone51/spark_26580_followup.
    
    Lead-authored-by: yi.wu <yi...@databricks.com>
    Co-authored-by: wuyi <yi...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 82ce4753aafafc590f0cd6ba47d5db20a2d6137b)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 docs/sql-migration-guide.md                        |  2 +-
 .../scala/org/apache/spark/ml/Transformer.scala    |  5 ++--
 .../scala/org/apache/spark/ml/feature/LSH.scala    | 11 ++++----
 .../scala/org/apache/spark/ml/fpm/FPGrowth.scala   | 11 +++++---
 .../org/apache/spark/ml/feature/LSHTest.scala      |  9 +++---
 project/MimaExcludes.scala                         |  3 ++
 .../org/apache/spark/sql/internal/SQLConf.scala    |  8 ++++++
 .../sql/expressions/UserDefinedFunction.scala      |  2 +-
 .../scala/org/apache/spark/sql/functions.scala     |  9 ++++++
 .../test/scala/org/apache/spark/sql/UDFSuite.scala | 33 ++++++++++++++--------
 10 files changed, 62 insertions(+), 31 deletions(-)

diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index 9b74b45..7e52e69 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -63,7 +63,7 @@ license: |
 
   - Since Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they match to the pattern defined by the JSON option `timestampFormat`. Set JSON option `inferTimestamp` to `false` to disable such type inferring.
 
-  - In Spark version 2.4 and earlier, if `org.apache.spark.sql.functions.udf(Any, DataType)` gets a Scala closure with primitive-type argument, the returned UDF will return null if the input values is null. Since Spark 3.0, the UDF will return the default value of the Java type if the input value is null. For example, `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` will return null in Spark 2.4 and earlier if column `x` is null, and return 0 in Spark 3.0. This behavior change is int [...]
+  - Since Spark 3.0, using `org.apache.spark.sql.functions.udf(AnyRef, DataType)` is not allowed by default. Set `spark.sql.legacy.allowUntypedScalaUDF` to true to keep using it. But please note that, in Spark version 2.4 and earlier, if `org.apache.spark.sql.functions.udf(AnyRef, DataType)` gets a Scala closure with primitive-type argument, the returned UDF will return null if the input values is null. However, since Spark 3.0, the UDF will return the default value of the Java type if t [...]
 
   - Since Spark 3.0, Proleptic Gregorian calendar is used in parsing, formatting, and converting dates and timestamps as well as in extracting sub-components like years, days and etc. Spark 3.0 uses Java 8 API classes from the java.time packages that based on ISO chronology (https://docs.oracle.com/javase/8/docs/api/java/time/chrono/IsoChronology.html). In Spark version 2.4 and earlier, those operations are performed by using the hybrid calendar (Julian + Gregorian, see https://docs.orac [...]
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 7874fc2..1652131 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml
 
 import scala.annotation.varargs
+import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.internal.Logging
@@ -79,7 +80,7 @@ abstract class Transformer extends PipelineStage {
  * result as a new column.
  */
 @DeveloperApi
-abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
+abstract class UnaryTransformer[IN: TypeTag, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
   extends Transformer with HasInputCol with HasOutputCol with Logging {
 
   /** @group setParam */
@@ -118,7 +119,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema, logging = true)
-    val transformUDF = udf(this.createTransformFunc, outputDataType)
+    val transformUDF = udf(this.createTransformFunc)
     dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))),
       outputSchema($(outputCol)).metadata)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index 0174101..6d5c7c5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -98,7 +98,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val transformUDF = udf(hashFunction(_: Vector), DataTypes.createArrayType(new VectorUDT))
+    val transformUDF = udf(hashFunction(_: Vector))
     dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
   }
 
@@ -128,14 +128,13 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
       }
 
       // In the origin dataset, find the hash value that hash the same bucket with the key
-      val sameBucketWithKeyUDF = udf((x: Seq[Vector]) =>
-        sameBucket(x, keyHash), DataTypes.BooleanType)
+      val sameBucketWithKeyUDF = udf((x: Seq[Vector]) => sameBucket(x, keyHash))
 
       modelDataset.filter(sameBucketWithKeyUDF(col($(outputCol))))
     } else {
       // In the origin dataset, find the hash value that is closest to the key
       // Limit the use of hashDist since it's controversial
-      val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash), DataTypes.DoubleType)
+      val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash))
       val hashDistCol = hashDistUDF(col($(outputCol)))
       val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)
 
@@ -172,7 +171,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
     }
 
     // Get the top k nearest neighbor by their distance to the key
-    val keyDistUDF = udf((x: Vector) => keyDistance(x, key), DataTypes.DoubleType)
+    val keyDistUDF = udf((x: Vector) => keyDistance(x, key))
     val modelSubsetWithDistCol = modelSubset.withColumn(distCol, keyDistUDF(col($(inputCol))))
     modelSubsetWithDistCol.sort(distCol).limit(numNearestNeighbors)
   }
@@ -290,7 +289,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
       .drop(explodeCols: _*).distinct()
 
     // Add a new column to store the distance of the two rows.
-    val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y), DataTypes.DoubleType)
+    val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y))
     val joinedDatasetWithDist = joinedDataset.select(col("*"),
       distUDF(col(s"$leftColName.${$(inputCol)}"), col(s"$rightColName.${$(inputCol)}")).as(distCol)
     )
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 4d001c1..e50d425 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -29,10 +29,10 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.HasPredictionCol
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
-import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules,
-  FPGrowth => MLlibFPGrowth}
+import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, FPGrowth => MLlibFPGrowth}
 import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
 import org.apache.spark.sql._
+import org.apache.spark.sql.expressions.SparkUserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -286,14 +286,17 @@ class FPGrowthModel private[ml] (
 
     val dt = dataset.schema($(itemsCol)).dataType
     // For each rule, examine the input items and summarize the consequents
-    val predictUDF = udf((items: Seq[Any]) => {
+    val predictUDF = SparkUserDefinedFunction((items: Seq[Any]) => {
       if (items != null) {
         val itemset = items.toSet
         brRules.value.filter(_._1.forall(itemset.contains))
           .flatMap(_._2.filter(!itemset.contains(_))).distinct
       } else {
         Seq.empty
-      }}, dt)
+      }},
+      dt,
+      Nil
+    )
     dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol))))
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
index 76a4acd..1d052fb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
@@ -76,9 +76,8 @@ private[ml] object LSHTest {
 
     // Perform a cross join and label each pair of same_bucket and distance
     val pairs = transformedData.as("a").crossJoin(transformedData.as("b"))
-    val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType)
-    val sameBucket = udf((x: Seq[Vector], y: Seq[Vector]) => model.hashDistance(x, y) == 0.0,
-      DataTypes.BooleanType)
+    val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y))
+    val sameBucket = udf((x: Seq[Vector], y: Seq[Vector]) => model.hashDistance(x, y) == 0.0)
     val result = pairs
       .withColumn("same_bucket", sameBucket(col(s"a.$outputCol"), col(s"b.$outputCol")))
       .withColumn("distance", distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")))
@@ -110,7 +109,7 @@ private[ml] object LSHTest {
     val model = lsh.fit(dataset)
 
     // Compute expected
-    val distUDF = udf((x: Vector) => model.keyDistance(x, key), DataTypes.DoubleType)
+    val distUDF = udf((x: Vector) => model.keyDistance(x, key))
     val expected = dataset.sort(distUDF(col(model.getInputCol))).limit(k)
 
     // Compute actual
@@ -148,7 +147,7 @@ private[ml] object LSHTest {
     val inputCol = model.getInputCol
 
     // Compute expected
-    val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType)
+    val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y))
     val expected = datasetA.as("a").crossJoin(datasetB.as("b"))
       .filter(distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")) < threshold)
 
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 65ffa22..2a72284 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -74,6 +74,9 @@ object MimaExcludes {
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.getRuns"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.setRuns"),
 
+    // [SPARK-26580][SQL][ML][FOLLOW-UP] Throw exception when use untyped UDF by default
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.UnaryTransformer.this"),
+
     // [SPARK-27090][CORE] Removing old LEGACY_DRIVER_IDENTIFIER ("<driver>")
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.LEGACY_DRIVER_IDENTIFIER"),
     // [SPARK-25838] Remove formatVersion from Saveable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index bcec22f..c73d8b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2017,6 +2017,14 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val LEGACY_ALLOW_UNTYPED_SCALA_UDF =
+    buildConf("spark.sql.legacy.allowUntypedScalaUDF")
+      .internal()
+      .doc("When set to true, user is allowed to use org.apache.spark.sql.functions." +
+        "udf(f: AnyRef, dataType: DataType). Otherwise, exception will be throw.")
+      .booleanConf
+      .createWithDefault(false)
+
   val TRUNCATE_TABLE_IGNORE_PERMISSION_ACL =
     buildConf("spark.sql.truncateTable.ignorePermissionAcl.enabled")
       .internal()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 85b2cd3..c50168c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -90,7 +90,7 @@ sealed abstract class UserDefinedFunction {
   def asNondeterministic(): UserDefinedFunction
 }
 
-private[sql] case class SparkUserDefinedFunction(
+private[spark] case class SparkUserDefinedFunction(
     f: AnyRef,
     dataType: DataType,
     inputSchemas: Seq[Option[ScalaReflection.Schema]],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2d5504a..c60df14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4733,6 +4733,15 @@ object functions {
    * @since 2.0.0
    */
   def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
+    if (!SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF)) {
+      val errorMsg = "You're using untyped Scala UDF, which does not have the input type " +
+        "information. Spark may blindly pass null to the Scala closure with primitive-type " +
+        "argument, and the closure will see the default value of the Java type for the null " +
+        "argument, e.g. `udf((x: Int) => x, IntegerType)`, the result is 0 for null input. " +
+        "You could use other typed Scala UDF APIs to avoid this problem, or set " +
+        s"${SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key} to true and use this API with caution."
+      throw new AnalysisException(errorMsg)
+    }
     SparkUserDefinedFunction(f, dataType, inputSchemas = Nil)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index cc39955..cbe2e91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -134,10 +134,12 @@ class UDFSuite extends QueryTest with SharedSparkSession {
     assert(df1.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
     assert(df1.head().getDouble(0) >= 0.0)
 
-    val bar = udf(() => Math.random(), DataTypes.DoubleType).asNondeterministic()
-    val df2 = testData.select(bar())
-    assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
-    assert(df2.head().getDouble(0) >= 0.0)
+    withSQLConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key -> "true") {
+      val bar = udf(() => Math.random(), DataTypes.DoubleType).asNondeterministic()
+      val df2 = testData.select(bar())
+      assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic))
+      assert(df2.head().getDouble(0) >= 0.0)
+    }
 
     val javaUdf = udf(new UDF0[Double] {
       override def call(): Double = Math.random()
@@ -441,16 +443,23 @@ class UDFSuite extends QueryTest with SharedSparkSession {
   }
 
   test("SPARK-25044 Verify null input handling for primitive types - with udf(Any, DataType)") {
-    val f = udf((x: Int) => x, IntegerType)
-    checkAnswer(
-      Seq(Integer.valueOf(1), null).toDF("x").select(f($"x")),
-      Row(1) :: Row(0) :: Nil)
+    withSQLConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF.key -> "true") {
+      val f = udf((x: Int) => x, IntegerType)
+      checkAnswer(
+        Seq(Integer.valueOf(1), null).toDF("x").select(f($"x")),
+        Row(1) :: Row(0) :: Nil)
+
+      val f2 = udf((x: Double) => x, DoubleType)
+      checkAnswer(
+        Seq(java.lang.Double.valueOf(1.1), null).toDF("x").select(f2($"x")),
+        Row(1.1) :: Row(0.0) :: Nil)
+    }
 
-    val f2 = udf((x: Double) => x, DoubleType)
-    checkAnswer(
-      Seq(java.lang.Double.valueOf(1.1), null).toDF("x").select(f2($"x")),
-      Row(1.1) :: Row(0.0) :: Nil)
+  }
 
+  test("use untyped Scala UDF should fail by default") {
+    val e = intercept[AnalysisException](udf((x: Int) => x, IntegerType))
+    assert(e.getMessage.contains("You're using untyped Scala UDF"))
   }
 
   test("SPARK-26308: udf with decimal") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org