You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2021/01/11 19:32:20 UTC
[spark] branch master updated: [SPARK-34002][SQL] Fix the usage of
encoder in ScalaUDF
This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0bcbafb [SPARK-34002][SQL] Fix the usage of encoder in ScalaUDF
0bcbafb is described below
commit 0bcbafb4b868234d695b553f215e21fe66e7ff77
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Mon Jan 11 11:31:35 2021 -0800
[SPARK-34002][SQL] Fix the usage of encoder in ScalaUDF
### What changes were proposed in this pull request?
This patch fixes a few issues with how encoders are used to serialize inputs and outputs in `ScalaUDF`.
### Why are the changes needed?
This fixes bugs in how encoders are used in Scala UDFs. First, the output data type should be corrected to the data type of the object serializer. Second, `catalystConverter` should not serialize `Option[_]` as an ordinary row, because in the `ScalaUDF` case the value is serialized to a column, not to the top-level row. Otherwise, there will be a redundant `value` struct wrapping the serialized `Option[_]` object.
### Does this PR introduce _any_ user-facing change?
Yes, this fixes a bug in `ScalaUDF`.
### How was this patch tested?
Unit test.
Closes #31103 from viirya/SPARK-34002.
Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
.../sql/catalyst/encoders/ExpressionEncoder.scala | 6 --
.../spark/sql/catalyst/expressions/ScalaUDF.scala | 2 +-
.../org/apache/spark/sql/UDFRegistration.scala | 64 ++++++++++++++--------
.../scala/org/apache/spark/sql/functions.scala | 24 ++++----
.../scala/org/apache/spark/sql/DatasetSuite.scala | 27 +++++++++
5 files changed, 80 insertions(+), 43 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 80a0374..665acdb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -22,7 +22,6 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.ScalaReflection.Schema
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
import org.apache.spark.sql.catalyst.expressions._
@@ -304,11 +303,6 @@ case class ExpressionEncoder[T](
StructField(s.name, s.dataType, s.nullable)
})
- def dataTypeAndNullable: Schema = {
- val dataType = if (isSerializedAsStruct) schema else schema.head.dataType
- Schema(dataType, objSerializer.nullable)
- }
-
/**
* Returns true if the type `T` is serialized as a struct by `objSerializer`.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 4a89d24..8d25585 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -120,7 +120,7 @@ case class ScalaUDF(
*/
private def catalystConverter: Any => Any = outputEncoder.map { enc =>
val toRow = enc.createSerializer().asInstanceOf[Any => Any]
- if (enc.isSerializedAsStruct) {
+ if (enc.isSerializedAsStructForTopLevel) {
value: Any =>
if (value == null) null else toRow(value).asInstanceOf[InternalRow]
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 237cfe1..7567ed6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -49,6 +49,8 @@ import org.apache.spark.util.Utils
@Stable
class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging {
+ import UDFRegistration._
+
protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = {
log.debug(
s"""
@@ -135,7 +137,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
| */
|def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = {
| val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- | val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ | val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
| val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = $inputEncoders
| val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
| val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -183,7 +185,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -204,7 +206,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -225,7 +227,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -246,7 +248,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -267,7 +269,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -288,7 +290,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -309,7 +311,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -330,7 +332,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -351,7 +353,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -372,7 +374,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -393,7 +395,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -414,7 +416,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -435,7 +437,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -456,7 +458,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -477,7 +479,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -498,7 +500,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -519,7 +521,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -540,7 +542,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -561,7 +563,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -582,7 +584,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -603,7 +605,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -624,7 +626,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -645,7 +647,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[ [...]
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@@ -1121,3 +1123,17 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
// scalastyle:on line.size.limit
}
+
+private[sql] object UDFRegistration {
+ /**
+ * Returns the output schema (data type and nullability) of a `ScalaUDF` output encoder.
+ *
+ * Because `ScalaUDF` serializes its result into an individual column, not a whole row,
+ * the schema is taken from the plain object serializer (`objSerializer`) rather than
+ * from `serializer`, which is adapted for use as a top-level row.
+ */
+ def outputSchema(outputEncoder: ExpressionEncoder[_]): ScalaReflection.Schema = {
+ ScalaReflection.Schema(outputEncoder.objSerializer.dataType,
+ outputEncoder.objSerializer.nullable)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 764e088..ff3f568 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4603,7 +4603,7 @@ object functions {
| */
|def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = {
| val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- | val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ | val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
| val inputEncoders = $inputEncoders
| val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
| if (nullable) udf else udf.asNonNullable()
@@ -4710,7 +4710,7 @@ object functions {
*/
def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@@ -4727,7 +4727,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@@ -4744,7 +4744,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@@ -4761,7 +4761,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@@ -4778,7 +4778,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@@ -4795,7 +4795,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@@ -4812,7 +4812,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@@ -4829,7 +4829,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@@ -4846,7 +4846,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@@ -4863,7 +4863,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@@ -4880,7 +4880,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
- val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
+ val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3a169e4..2ec4c69 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1982,8 +1982,35 @@ class DatasetSuite extends QueryTest
assert(timezone == "Asia/Shanghai")
}
}
+
+ test("SPARK-34002: Fix broken Option input/output in UDF") {
+ def f1(bar: Bar): Option[Bar] = {
+ None
+ }
+
+ def f2(bar: Option[Bar]): Option[Bar] = {
+ bar
+ }
+
+ val udf1 = udf(f1 _).withName("f1")
+ val udf2 = udf(f2 _).withName("f2")
+
+ val df = (1 to 2).map(i => Tuple1(Bar(1))).toDF("c0")
+ val withUDF = df
+ .withColumn("c1", udf1(col("c0")))
+ .withColumn("c2", udf2(col("c1")))
+
+ assert(withUDF.schema == StructType(
+ StructField("c0", StructType(StructField("a", IntegerType, false) :: Nil)) ::
+ StructField("c1", StructType(StructField("a", IntegerType, false) :: Nil)) ::
+ StructField("c2", StructType(StructField("a", IntegerType, false) :: Nil)) :: Nil))
+
+ checkAnswer(withUDF, Row(Row(1), null, null) :: Row(Row(1), null, null) :: Nil)
+ }
}
+case class Bar(a: Int)
+
object AssertExecutionId {
def apply(id: Long): Long = {
assert(TaskContext.get().getLocalProperty(SQLExecution.EXECUTION_ID_KEY) != null)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org