You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2021/01/11 19:32:35 UTC

[spark] branch branch-3.1 updated: [SPARK-34002][SQL] Fix the usage of encoder in ScalaUDF

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new b490cb8  [SPARK-34002][SQL] Fix the usage of encoder in ScalaUDF
b490cb8 is described below

commit b490cb8e9cf6d8315a4154423737227faab041fc
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Mon Jan 11 11:31:35 2021 -0800

    [SPARK-34002][SQL] Fix the usage of encoder in ScalaUDF
    
    ### What changes were proposed in this pull request?
    
    This patch fixes a few issues when using encoders to serialize input/output in `ScalaUDF`.
    
    ### Why are the changes needed?
    
    This fixes a bug when using encoders in `ScalaUDF`. First, the output data type should be corrected to the corresponding data type of the object serializer. Second, `catalystConverter` should not serialize `Option[_]` as an ordinary row, because in the `ScalaUDF` case it is serialized to a column, not a top-level row. Otherwise, there will be a redundant `value` struct wrapping the serialized `Option[_]` object.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this fixes a bug in `ScalaUDF`.
    
    ### How was this patch tested?
    
    Unit test.
    
    Closes #31103 from viirya/SPARK-34002.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
    (cherry picked from commit 0bcbafb4b868234d695b553f215e21fe66e7ff77)
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../sql/catalyst/encoders/ExpressionEncoder.scala  |  6 --
 .../spark/sql/catalyst/expressions/ScalaUDF.scala  |  2 +-
 .../org/apache/spark/sql/UDFRegistration.scala     | 64 ++++++++++++++--------
 .../scala/org/apache/spark/sql/functions.scala     | 24 ++++----
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 27 +++++++++
 5 files changed, 80 insertions(+), 43 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 9ab3804..8635b1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -22,7 +22,6 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.ScalaReflection.Schema
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
 import org.apache.spark.sql.catalyst.expressions._
@@ -304,11 +303,6 @@ case class ExpressionEncoder[T](
     StructField(s.name, s.dataType, s.nullable)
   })
 
-  def dataTypeAndNullable: Schema = {
-    val dataType = if (isSerializedAsStruct) schema else schema.head.dataType
-    Schema(dataType, objSerializer.nullable)
-  }
-
   /**
    * Returns true if the type `T` is serialized as a struct by `objSerializer`.
    */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 0a69d5a..9f8a8b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -120,7 +120,7 @@ case class ScalaUDF(
    */
   private def catalystConverter: Any => Any = outputEncoder.map { enc =>
     val toRow = enc.createSerializer().asInstanceOf[Any => Any]
-    if (enc.isSerializedAsStruct) {
+    if (enc.isSerializedAsStructForTopLevel) {
       value: Any =>
         if (value == null) null else toRow(value).asInstanceOf[InternalRow]
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 237cfe1..7567ed6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -49,6 +49,8 @@ import org.apache.spark.util.Utils
 @Stable
 class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging {
 
+  import UDFRegistration._
+
   protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = {
     log.debug(
       s"""
@@ -135,7 +137,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
         | */
         |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = {
         |  val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-        |  val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+        |  val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
         |  val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = $inputEncoders
         |  val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
         |  val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -183,7 +185,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -204,7 +206,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -225,7 +227,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -246,7 +248,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -267,7 +269,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -288,7 +290,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -309,7 +311,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -330,7 +332,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -351,7 +353,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -372,7 +374,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -393,7 +395,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -414,7 +416,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -435,7 +437,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -456,7 +458,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -477,7 +479,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -498,7 +500,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -519,7 +521,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -540,7 +542,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -561,7 +563,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -582,7 +584,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -603,7 +605,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -624,7 +626,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -645,7 +647,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
     val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
     val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -1121,3 +1123,17 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   // scalastyle:on line.size.limit
 
 }
+
+private[sql] object UDFRegistration {
+  /**
+   * Returns the schema (data type and nullability) of the output encoder of a `ScalaUDF`.
+   *
+   * A `ScalaUDF` result is serialized into a single column, not a whole row, so the schema
+   * is taken from the encoder's plain `objSerializer` rather than from `serializer`, which
+   * has been adapted for use as a top-level row.
+   */
+  def outputSchema(outputEncoder: ExpressionEncoder[_]): ScalaReflection.Schema = {
+    ScalaReflection.Schema(outputEncoder.objSerializer.dataType,
+      outputEncoder.objSerializer.nullable)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 5b1ee2d..ceb211c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4553,7 +4553,7 @@ object functions {
       | */
       |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = {
       |  val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-      |  val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+      |  val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
       |  val inputEncoders = $inputEncoders
       |  val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
       |  if (nullable) udf else udf.asNonNullable()
@@ -4660,7 +4660,7 @@ object functions {
    */
   def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders = Nil
     val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
     if (nullable) udf else udf.asNonNullable()
@@ -4677,7 +4677,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
     if (nullable) udf else udf.asNonNullable()
@@ -4694,7 +4694,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
     if (nullable) udf else udf.asNonNullable()
@@ -4711,7 +4711,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
     if (nullable) udf else udf.asNonNullable()
@@ -4728,7 +4728,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
     if (nullable) udf else udf.asNonNullable()
@@ -4745,7 +4745,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
     if (nullable) udf else udf.asNonNullable()
@@ -4762,7 +4762,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
     if (nullable) udf else udf.asNonNullable()
@@ -4779,7 +4779,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
     if (nullable) udf else udf.asNonNullable()
@@ -4796,7 +4796,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
     if (nullable) udf else udf.asNonNullable()
@@ -4813,7 +4813,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
     if (nullable) udf else udf.asNonNullable()
@@ -4830,7 +4830,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
     val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
-    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+    val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
     val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil
     val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
     if (nullable) udf else udf.asNonNullable()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 67e3ad6..69fbb9b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1955,8 +1955,35 @@ class DatasetSuite extends QueryTest
       assert(timezone == "Asia/Shanghai")
     }
   }
+
+  test("SPARK-34002: Fix broken Option input/output in UDF") {
+    def f1(bar: Bar): Option[Bar] = {
+      None
+    }
+
+    def f2(bar: Option[Bar]): Option[Bar] = {
+      bar
+    }
+
+    val udf1 = udf(f1 _).withName("f1")
+    val udf2 = udf(f2 _).withName("f2")
+
+    val df = (1 to 2).map(i => Tuple1(Bar(1))).toDF("c0")
+    val withUDF = df
+      .withColumn("c1", udf1(col("c0")))
+      .withColumn("c2", udf2(col("c1")))
+
+    assert(withUDF.schema == StructType(
+      StructField("c0", StructType(StructField("a", IntegerType, false) :: Nil)) ::
+        StructField("c1", StructType(StructField("a", IntegerType, false) :: Nil)) ::
+        StructField("c2", StructType(StructField("a", IntegerType, false) :: Nil)) :: Nil))
+
+    checkAnswer(withUDF, Row(Row(1), null, null) :: Row(Row(1), null, null) :: Nil)
+  }
 }
 
+case class Bar(a: Int)
+
 object AssertExecutionId {
   def apply(id: Long): Long = {
     assert(TaskContext.get().getLocalProperty(SQLExecution.EXECUTION_ID_KEY) != null)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org