Posted to commits@spark.apache.org by me...@apache.org on 2016/03/17 00:59:40 UTC

spark git commit: [SPARK-11011][SQL] Narrow type of UDT serialization

Repository: spark
Updated Branches:
  refs/heads/master 77ba3021c -> d4d84936f


[SPARK-11011][SQL] Narrow type of UDT serialization

## What changes were proposed in this pull request?

Narrow down the parameter type of `UserDefinedType#serialize()`. Currently the parameter type is `Any`; it logically makes more sense to narrow it down to the type of the actual user-defined type.
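
As a minimal sketch (not part of this patch), a hypothetical UDT under the new signature would look like this; the `Point` class and its fields are assumptions for illustration, modeled on the `ExamplePointUDT` changes in the diff below:

```scala
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._

// Hypothetical user class, used here only for illustration.
case class Point(x: Double, y: Double)

class PointUDT extends UserDefinedType[Point] {

  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // With the narrowed signature the compiler enforces the input type,
  // so the defensive `obj match { case p: Point => ... }` that the old
  // `serialize(obj: Any)` signature required is no longer needed.
  override def serialize(p: Point): GenericArrayData =
    new GenericArrayData(Array[Any](p.x, p.y))

  override def deserialize(datum: Any): Point = datum match {
    case data: ArrayData => Point(data.getDouble(0), data.getDouble(1))
  }

  override def userClass: Class[Point] = classOf[Point]
}
```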

## How was this patch tested?

Existing tests were run successfully on a local machine.

Author: Jakob Odersky <ja...@odersky.com>

Closes #11379 from jodersky/SPARK-11011-udt-types.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d4d84936
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d4d84936
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d4d84936

Branch: refs/heads/master
Commit: d4d84936fb82bee91f4b04608de9f75c293ccc9e
Parents: 77ba302
Author: Jakob Odersky <ja...@odersky.com>
Authored: Wed Mar 16 16:59:36 2016 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Mar 16 16:59:36 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/mllib/linalg/Matrices.scala   |  2 +-
 .../org/apache/spark/mllib/linalg/Vectors.scala    |  2 +-
 project/MimaExcludes.scala                         |  2 ++
 .../sql/catalyst/CatalystTypeConverters.scala      | 10 +++++-----
 .../apache/spark/sql/types/UserDefinedType.scala   |  7 ++-----
 .../sql/catalyst/analysis/AnalysisErrorSuite.scala | 17 +++++------------
 .../sql/catalyst/encoders/RowEncoderSuite.scala    | 13 +++++--------
 .../apache/spark/sql/test/ExamplePointUDT.scala    | 13 +++++--------
 .../apache/spark/sql/UserDefinedTypeSuite.scala    |  7 ++-----
 .../datasources/parquet/ParquetQuerySuite.scala    | 11 ++++-------
 10 files changed, 32 insertions(+), 52 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d4d84936/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index fdede2a..157f2db 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -177,7 +177,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
       ))
   }
 
-  override def serialize(obj: Any): InternalRow = {
+  override def serialize(obj: Matrix): InternalRow = {
     val row = new GenericMutableRow(7)
     obj match {
       case sm: SparseMatrix =>

http://git-wip-us.apache.org/repos/asf/spark/blob/d4d84936/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index cecfd06..0f0c3a2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -203,7 +203,7 @@ class VectorUDT extends UserDefinedType[Vector] {
       StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
   }
 
-  override def serialize(obj: Any): InternalRow = {
+  override def serialize(obj: Vector): InternalRow = {
     obj match {
       case SparseVector(size, indices, values) =>
         val row = new GenericMutableRow(4)

http://git-wip-us.apache.org/repos/asf/spark/blob/d4d84936/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 59c7e7d..ffc6fa0 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -292,6 +292,8 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"),
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$")
       ) ++ Seq(
+        //SPARK-11011 UserDefinedType serialization should be strongly typed
+        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"),
         // SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions
         ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"),
         ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition")

http://git-wip-us.apache.org/repos/asf/spark/blob/d4d84936/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 2ec0ff5..9bfc381 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -136,16 +136,16 @@ object CatalystTypeConverters {
     override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column, dataType)
   }
 
-  private case class UDTConverter(
-      udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
+  private case class UDTConverter[A >: Null](
+      udt: UserDefinedType[A]) extends CatalystTypeConverter[A, A, Any] {
     // toCatalyst (it calls toCatalystImpl) will do null check.
-    override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
+    override def toCatalystImpl(scalaValue: A): Any = udt.serialize(scalaValue)
 
-    override def toScala(catalystValue: Any): Any = {
+    override def toScala(catalystValue: Any): A = {
       if (catalystValue == null) null else udt.deserialize(catalystValue)
     }
 
-    override def toScalaImpl(row: InternalRow, column: Int): Any =
+    override def toScalaImpl(row: InternalRow, column: Int): A =
       toScala(row.get(column, udt.sqlType))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d4d84936/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 9d2449f..dabf9a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -37,7 +37,7 @@ import org.apache.spark.annotation.DeveloperApi
  * The conversion via `deserialize` occurs when reading from a `DataFrame`.
  */
 @DeveloperApi
-abstract class UserDefinedType[UserType] extends DataType with Serializable {
+abstract class UserDefinedType[UserType >: Null] extends DataType with Serializable {
 
   /** Underlying storage type for this UDT */
   def sqlType: DataType
@@ -50,11 +50,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
 
   /**
    * Convert the user type to a SQL datum
-   *
-   * TODO: Can we make this take obj: UserType?  The issue is in
-   *       CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType.
    */
-  def serialize(obj: Any): Any
+  def serialize(obj: UserType): Any
 
   /** Convert a SQL datum to the user type */
   def deserialize(datum: Any): UserType
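
The new `UserType >: Null` lower bound is what lets `UDTConverter.toScala` in the hunk above return `null` for null catalyst values: an unconstrained type parameter does not admit `null` as a value. A minimal sketch of the pattern (not part of this commit; `toScalaSketch` is a made-up name):

```scala
// Compiles only because `A >: Null` guarantees A is a nullable
// (reference) type, so `null` is a valid value of type A.
def toScalaSketch[A >: Null](value: Any)(deserialize: Any => A): A =
  if (value == null) null else deserialize(value)
```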

http://git-wip-us.apache.org/repos/asf/spark/blob/d4d84936/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 4e7bbc3..1b29752 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -36,11 +36,7 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
 
   override def sqlType: DataType = IntegerType
 
-  override def serialize(obj: Any): Int = {
-    obj match {
-      case groupableData: GroupableData => groupableData.data
-    }
-  }
+  override def serialize(groupableData: GroupableData): Int = groupableData.data
 
   override def deserialize(datum: Any): GroupableData = {
     datum match {
@@ -60,13 +56,10 @@ private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {
 
   override def sqlType: DataType = MapType(IntegerType, IntegerType)
 
-  override def serialize(obj: Any): MapData = {
-    obj match {
-      case groupableData: UngroupableData =>
-        val keyArray = new GenericArrayData(groupableData.data.keys.toSeq)
-        val valueArray = new GenericArrayData(groupableData.data.values.toSeq)
-        new ArrayBasedMapData(keyArray, valueArray)
-    }
+  override def serialize(ungroupableData: UngroupableData): MapData = {
+    val keyArray = new GenericArrayData(ungroupableData.data.keys.toSeq)
+    val valueArray = new GenericArrayData(ungroupableData.data.values.toSeq)
+    new ArrayBasedMapData(keyArray, valueArray)
   }
 
   override def deserialize(datum: Any): UngroupableData = {

http://git-wip-us.apache.org/repos/asf/spark/blob/d4d84936/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index f119c6f..bf0360c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -47,14 +47,11 @@ class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
 
   override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
 
-  override def serialize(obj: Any): GenericArrayData = {
-    obj match {
-      case p: ExamplePoint =>
-        val output = new Array[Any](2)
-        output(0) = p.x
-        output(1) = p.y
-        new GenericArrayData(output)
-    }
+  override def serialize(p: ExamplePoint): GenericArrayData = {
+    val output = new Array[Any](2)
+    output(0) = p.x
+    output(1) = p.y
+    new GenericArrayData(output)
   }
 
   override def deserialize(datum: Any): ExamplePoint = {

http://git-wip-us.apache.org/repos/asf/spark/blob/d4d84936/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index e2c9fc4..695a5ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -42,14 +42,11 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
 
   override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
 
-  override def serialize(obj: Any): GenericArrayData = {
-    obj match {
-      case p: ExamplePoint =>
-        val output = new Array[Any](2)
-        output(0) = p.x
-        output(1) = p.y
-        new GenericArrayData(output)
-    }
+  override def serialize(p: ExamplePoint): GenericArrayData = {
+    val output = new Array[Any](2)
+    output(0) = p.x
+    output(1) = p.y
+    new GenericArrayData(output)
   }
 
   override def deserialize(datum: Any): ExamplePoint = {

http://git-wip-us.apache.org/repos/asf/spark/blob/d4d84936/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 9081bc7..8c4afb6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -45,11 +45,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
 
   override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
 
-  override def serialize(obj: Any): ArrayData = {
-    obj match {
-      case features: MyDenseVector =>
-        new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
-    }
+  override def serialize(features: MyDenseVector): ArrayData = {
+    new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
   }
 
   override def deserialize(datum: Any): MyDenseVector = {

http://git-wip-us.apache.org/repos/asf/spark/blob/d4d84936/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index fb99b0c..f8166c7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -590,14 +590,11 @@ object TestingUDT {
         .add("b", LongType, nullable = false)
         .add("c", DoubleType, nullable = false)
 
-    override def serialize(obj: Any): Any = {
+    override def serialize(n: NestedStruct): Any = {
       val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType))
-      obj match {
-        case n: NestedStruct =>
-          row.setInt(0, n.a)
-          row.setLong(1, n.b)
-          row.setDouble(2, n.c)
-      }
+      row.setInt(0, n.a)
+      row.setLong(1, n.b)
+      row.setDouble(2, n.c)
     }
 
     override def userClass: Class[NestedStruct] = classOf[NestedStruct]

