Posted to commits@spark.apache.org by do...@apache.org on 2020/07/14 19:22:49 UTC

[spark] branch master updated: [SPARK-32307][SQL] ScalaUDF's canonicalized expression should exclude inputEncoders

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a47b69a  [SPARK-32307][SQL] ScalaUDF's canonicalized expression should exclude inputEncoders
a47b69a is described below

commit a47b69a88a271e423271709ee491e2de57c5581b
Author: yi.wu <yi...@databricks.com>
AuthorDate: Tue Jul 14 12:19:01 2020 -0700

    [SPARK-32307][SQL] ScalaUDF's canonicalized expression should exclude inputEncoders
    
    ### What changes were proposed in this pull request?
    
    Override `canonicalized` to empty the `inputEncoders` for the canonicalized `ScalaUDF`.
    
    ### Why are the changes needed?
    
    The following currently fails on `branch-3.0`, but not on the Apache Spark 3.0.0 release.
    
    ```scala
    spark.udf.register("key", udf((m: Map[String, String]) => m.keys.head.toInt))
    Seq(Map("1" -> "one", "2" -> "two")).toDF("a").createOrReplaceTempView("t")
    checkAnswer(sql("SELECT key(a) AS k FROM t GROUP BY key(a)"), Row(1) :: Nil)
    
    [info]   org.apache.spark.sql.AnalysisException: expression 't.`a`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;;
    [info] Aggregate [UDF(a#6)], [UDF(a#6) AS k#8]
    [info] +- SubqueryAlias t
    [info]    +- Project [value#3 AS a#6]
    [info]       +- LocalRelation [value#3]
    [info]   at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.failAnalysis(CheckAnalysis.scala:49)
    [info]   at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.failAnalysis$(CheckAnalysis.scala:48)
    [info]   at org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:130)
    [info]   at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkValidAggregateExpression$1(CheckAnalysis.scala:257)
    [info]   at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis$10(CheckAnalysis.scala:259)
    [info]   at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis$10$adapted(CheckAnalysis.scala:259)
    [info]   at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    [info]   at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    [info]   at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    [info]   at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkValidAggregateExpression$1(CheckAnalysis.scala:259)
    [info]   at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis$10(CheckAnalysis.scala:259)
    [info]   at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis$10$adapted(CheckAnalysis.scala:259)
    [info]   at scala.collection.immutable.List.foreach(List.scala:392)
    [info]   at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkValidAggregateExpression$1(CheckAnalysis.scala:259)
    ...
    ```
    
    We use the rule `ResolveEncodersInUDF` to resolve `inputEncoders`, and at the end the original `ScalaUDF` instance is replaced by a new `ScalaUDF` instance carrying the resolved encoders. Note that during encoder resolution, types like `map` and `array` are resolved to new expressions (e.g. `MapObjects`, `CatalystToExternalMap`).
    
    However, `ExpressionEncoder` can't be canonicalized. Thus, the canonicalized `ScalaUDF`s differ even if their original `ScalaUDF`s are the same. As a result, `checkValidAggregateExpression` fails when this `ScalaUDF` is used as a grouping expression.
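    
    The effect can be sketched with a toy model that does not use Spark's classes. Everything below (`ToyEncoder`, `ToyExpr`, `ToyAttr`, `NaiveUdf`, `FixedUdf`) is a hypothetical stand-in. In particular, `ToyEncoder` is assumed to compare by reference, as a simplification of an encoder whose resolved form differs between instances. The sketch only illustrates why dropping the encoders from the canonical form makes two otherwise identical UDF calls compare equal.
    
    ```scala
    // Toy model only: these are NOT Spark's classes; all names are hypothetical.
    
    // Stand-in for a resolved encoder. A plain class, so == is reference equality.
    final class ToyEncoder(val description: String)
    
    sealed trait ToyExpr {
      def canonicalized: ToyExpr
      // Two expressions are semantically equal when their canonical forms match.
      def semanticEquals(other: ToyExpr): Boolean =
        canonicalized == other.canonicalized
    }
    
    final case class ToyAttr(name: String) extends ToyExpr {
      def canonicalized: ToyExpr = this
    }
    
    // Before the fix: encoders survive canonicalization.
    final case class NaiveUdf(name: String, children: Seq[ToyExpr], encoders: Seq[ToyEncoder])
        extends ToyExpr {
      def canonicalized: ToyExpr = copy(children = children.map(_.canonicalized))
    }
    
    // After the fix: encoders are emptied in the canonical form.
    final case class FixedUdf(name: String, children: Seq[ToyExpr], encoders: Seq[ToyEncoder])
        extends ToyExpr {
      def canonicalized: ToyExpr =
        copy(children = children.map(_.canonicalized), encoders = Nil)
    }
    
    object CanonicalizationSketch {
      // Each call yields a fresh encoder instance, mimicking per-occurrence resolution.
      def resolveEncoder(): ToyEncoder = new ToyEncoder("Map[String, String]")
    
      def main(args: Array[String]): Unit = {
        val a = ToyAttr("a")
    
        // The same UDF call appears in GROUP BY and in the SELECT list.
        val naiveGroup = NaiveUdf("key", Seq(a), Seq(resolveEncoder()))
        val naiveSelect = NaiveUdf("key", Seq(a), Seq(resolveEncoder()))
        println(naiveGroup.semanticEquals(naiveSelect)) // false
    
        val fixedGroup = FixedUdf("key", Seq(a), Seq(resolveEncoder()))
        val fixedSelect = FixedUdf("key", Seq(a), Seq(resolveEncoder()))
        println(fixedGroup.semanticEquals(fixedSelect)) // true
      }
    }
    ```
    
    In this toy model, the `false` result corresponds to the select-list `UDF(a#6)` not being recognized as the grouping expression `UDF(a#6)`, which is what surfaces as the `AnalysisException` above.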
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, users will not hit the exception after this fix.
    
    ### How was this patch tested?
    
    Added tests.
    
    Closes #29106 from Ngone51/spark-32307.
    
    Authored-by: yi.wu <yi...@databricks.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala |  6 ++++++
 sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala  | 12 ++++++++++++
 2 files changed, 18 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 44ee06a..6e2bd96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -59,6 +59,12 @@ case class ScalaUDF(
 
   override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})"
 
+  override lazy val canonicalized: Expression = {
+    // SPARK-32307: `ExpressionEncoder` can't be canonicalized, and technically we don't
+    // need it to identify a `ScalaUDF`.
+    Canonicalize.execute(copy(children = children.map(_.canonicalized), inputEncoders = Nil))
+  }
+
   /**
    * The analyzer should be aware of Scala primitive types so as to make the
    * UDF return null if there is any null input value of these types. On the
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 05a33f9..f0d5a61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -775,4 +775,16 @@ class UDFSuite extends QueryTest with SharedSparkSession {
     }
     assert(e2.getMessage.contains("UDFSuite$MalformedClassObject$MalformedPrimitiveFunction"))
   }
+
+  test("SPARK-32307: Aggression that use map type input UDF as group expression") {
+    spark.udf.register("key", udf((m: Map[String, String]) => m.keys.head.toInt))
+    Seq(Map("1" -> "one", "2" -> "two")).toDF("a").createOrReplaceTempView("t")
+    checkAnswer(sql("SELECT key(a) AS k FROM t GROUP BY key(a)"), Row(1) :: Nil)
+  }
+
+  test("SPARK-32307: Aggression that use array type input UDF as group expression") {
+    spark.udf.register("key", udf((m: Array[Int]) => m.head))
+    Seq(Array(1)).toDF("a").createOrReplaceTempView("t")
+    checkAnswer(sql("SELECT key(a) AS k FROM t GROUP BY key(a)"), Row(1) :: Nil)
+  }
 }

