You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/01/25 16:17:31 UTC
spark git commit: [SPARK-19311][SQL] fix UDT hierarchy issue
Repository: spark
Updated Branches:
refs/heads/master f1ddca5fc -> f6480b146
[SPARK-19311][SQL] fix UDT hierarchy issue
## What changes were proposed in this pull request?
acceptsType() in UDT will now accept not only the same type but also all of its subtypes (UDTs whose user class is a subclass of this UDT's user class)
## How was this patch tested?
Manually tested with a set of generated UDTs, fixing acceptsType() in my user-defined types; a unit test was also added to UserDefinedTypeSuite.
Please review http://spark.apache.org/contributing.html before opening a pull request.
Author: gmoehler <mo...@de.ibm.com>
Closes #16660 from gmoehler/master.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f6480b14
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f6480b14
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f6480b14
Branch: refs/heads/master
Commit: f6480b1467d0432fb2aa48c7a3a8a6e6679fd481
Parents: f1ddca5
Author: gmoehler <mo...@de.ibm.com>
Authored: Wed Jan 25 08:17:24 2017 -0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Wed Jan 25 08:17:24 2017 -0800
----------------------------------------------------------------------
.../spark/sql/types/UserDefinedType.scala | 8 +-
.../apache/spark/sql/UserDefinedTypeSuite.scala | 105 ++++++++++++++++++-
2 files changed, 110 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/f6480b14/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index c33219c..5a944e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -78,8 +78,12 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
*/
override private[spark] def asNullable: UserDefinedType[UserType] = this
- override private[sql] def acceptsType(dataType: DataType) =
- this.getClass == dataType.getClass
+ override private[sql] def acceptsType(dataType: DataType) = dataType match {
+ // Same UDT class, or a UDT whose user class is a subtype of this one's.
+ case that: UserDefinedType[_] =>
+ (this.getClass == that.getClass) || userClass.isAssignableFrom(that.userClass)
+ case _ => false
+ }
override def sql: String = sqlType.sql
http://git-wip-us.apache.org/repos/asf/spark/blob/f6480b14/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 474f17f..ea4a8ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql
import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
import org.apache.spark.sql.functions._
@@ -71,6 +72,77 @@ object UDT {
}
+// Type hierarchy used by the SPARK-19311 test: a sub-trait with its own UDT.
+
+// User type handled by ExampleBaseTypeUDT.
+sealed trait IExampleBaseType extends Serializable {
+ def field: Int
+}
+
+// Subtype of IExampleBaseType, handled by ExampleSubTypeUDT.
+sealed trait IExampleSubType extends IExampleBaseType
+
+// Concrete implementation of the base type.
+class ExampleBaseClass(override val field: Int) extends IExampleBaseType
+
+// Concrete subtype; inherits `field` from ExampleBaseClass rather than re-declaring it.
+class ExampleSubClass(field: Int)
+ extends ExampleBaseClass(field) with IExampleSubType
+
+// UDT for IExampleBaseType, stored as a struct with one non-nullable int field.
+class ExampleBaseTypeUDT extends UserDefinedType[IExampleBaseType] {
+
+ override def sqlType: StructType = {
+ StructType(Seq(
+ StructField("intfield", IntegerType, nullable = false)))
+ }
+
+ override def serialize(obj: IExampleBaseType): InternalRow = {
+ val row = new GenericInternalRow(1)
+ row.setInt(0, obj.field)
+ row
+ }
+
+ // Always rebuilds an ExampleBaseClass; anything but an InternalRow is rejected.
+ override def deserialize(datum: Any): IExampleBaseType = datum match {
+ case row: InternalRow =>
+ require(row.numFields == 1,
+ "ExampleBaseTypeUDT requires row with length == 1")
+ new ExampleBaseClass(row.getInt(0))
+ case other =>
+ throw new IllegalArgumentException(s"ExampleBaseTypeUDT cannot deserialize $other")
+ }
+
+ override def userClass: Class[IExampleBaseType] = classOf[IExampleBaseType]
+}
+
+// UDT for IExampleSubType, stored as a struct with one non-nullable int field.
+class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] {
+
+ override def sqlType: StructType = {
+ StructType(Seq(
+ StructField("intfield", IntegerType, nullable = false)))
+ }
+
+ override def serialize(obj: IExampleSubType): InternalRow = {
+ val row = new GenericInternalRow(1)
+ row.setInt(0, obj.field)
+ row
+ }
+
+ // Always rebuilds an ExampleSubClass; anything but an InternalRow is rejected.
+ override def deserialize(datum: Any): IExampleSubType = datum match {
+ case row: InternalRow =>
+ require(row.numFields == 1,
+ "ExampleSubTypeUDT requires row with length == 1")
+ new ExampleSubClass(row.getInt(0))
+ case other =>
+ throw new IllegalArgumentException(s"ExampleSubTypeUDT cannot deserialize $other")
+ }
+
+ override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
+}
+
class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
import testImplicits._
@@ -194,4 +266,35 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
// call `collect` to make sure this query can pass analysis.
pointsRDD.as[MyLabeledPoint].map(_.copy(label = 2.0)).collect()
}
+
+ test("SPARK-19311: UDFs disregard UDT type hierarchy") {
+ UDTRegistration.register(classOf[IExampleBaseType].getName,
+ classOf[ExampleBaseTypeUDT].getName)
+ UDTRegistration.register(classOf[IExampleSubType].getName,
+ classOf[ExampleSubTypeUDT].getName)
+
+ // UDF that returns a base class object
+ sqlContext.udf.register("doUDF", (param: Int) => {
+ new ExampleBaseClass(param)
+ }: IExampleBaseType)
+
+ // UDF that returns a derived class object
+ sqlContext.udf.register("doSubTypeUDF", (param: Int) => {
+ new ExampleSubClass(param)
+ }: IExampleSubType)
+
+ // UDF that takes a base class object as parameter
+ sqlContext.udf.register("doOtherUDF", (obj: IExampleBaseType) => {
+ obj.field
+ }: Int)
+
+ // Worked before SPARK-19311: doUDF's return UDT equals doOtherUDF's parameter UDT.
+ // Execute the query and check the value round-trips through the UDT.
+ checkAnswer(sql("SELECT doOtherUDF(doUDF(41))"), Row(41))
+
+ // Needs the SPARK-19311 fix: doSubTypeUDF returns a subtype of the
+ // parameter type of doOtherUDF, which acceptsType must now allow.
+ checkAnswer(sql("SELECT doOtherUDF(doSubTypeUDF(42))"), Row(42))
+ }
+
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org