You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/07/06 17:07:55 UTC

spark git commit: [SPARK-21204][SQL] Add support for Scala Set collection types in serialization

Repository: spark
Updated Branches:
  refs/heads/master 26ac085de -> 48e44b24a


[SPARK-21204][SQL] Add support for Scala Set collection types in serialization

## What changes were proposed in this pull request?

Currently, a `Dataset` containing a `Set` cannot be produced in Spark SQL. This PR adds support for serializing and deserializing `Set`.

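Because Spark SQL has no internal data type corresponding to `Set`, an array is the most appropriate representation for serializing a set. As a consequence, when a Catalyst array containing duplicate elements is deserialized into a `Set`, the duplicates are removed.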

## How was this patch tested?

Added unit tests.

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #18416 from viirya/SPARK-21204.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/48e44b24
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/48e44b24
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/48e44b24

Branch: refs/heads/master
Commit: 48e44b24a7663142176102ac4c6bf4242f103804
Parents: 26ac085
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Fri Jul 7 01:07:45 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Jul 7 01:07:45 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 28 ++++++++++++++++--
 .../catalyst/expressions/objects/objects.scala  |  5 ++--
 .../org/apache/spark/sql/SQLImplicits.scala     | 10 +++++++
 .../spark/sql/DataFrameAggregateSuite.scala     | 10 +++++++
 .../spark/sql/DatasetPrimitiveSuite.scala       | 31 ++++++++++++++++++++
 5 files changed, 79 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 814f2c1..4d5401f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -309,7 +309,10 @@ object ScalaReflection extends ScalaReflection {
           Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false)
         }
 
-      case t if t <:< localTypeOf[Seq[_]] =>
+      // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array
+      // to a `Set`, if there are duplicated elements, the elements will be de-duplicated.
+      case t if t <:< localTypeOf[Seq[_]] ||
+          t <:< localTypeOf[scala.collection.Set[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, elementNullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
@@ -327,8 +330,10 @@ object ScalaReflection extends ScalaReflection {
         }
 
         val companion = t.normalize.typeSymbol.companionSymbol.typeSignature
-        val cls = companion.declaration(newTermName("newBuilder")) match {
-          case NoSymbol => classOf[Seq[_]]
+        val cls = companion.member(newTermName("newBuilder")) match {
+          case NoSymbol if t <:< localTypeOf[Seq[_]] => classOf[Seq[_]]
+          case NoSymbol if t <:< localTypeOf[scala.collection.Set[_]] =>
+            classOf[scala.collection.Set[_]]
           case _ => mirror.runtimeClass(t.typeSymbol.asClass)
         }
         UnresolvedMapObjects(mapFunction, getPath, Some(cls))
@@ -502,6 +507,19 @@ object ScalaReflection extends ScalaReflection {
           serializerFor(_, valueType, valuePath, seenTypeSet),
           valueNullable = !valueType.typeSymbol.asClass.isPrimitive)
 
+      case t if t <:< localTypeOf[scala.collection.Set[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+
+        // There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array.
+        // Note that the property of `Set` is only kept when manipulating the data as domain object.
+        val newInput =
+          Invoke(
+           inputObject,
+           "toSeq",
+           ObjectType(classOf[Seq[_]]))
+
+        toCatalystArray(newInput, elementType)
+
       case t if t <:< localTypeOf[String] =>
         StaticInvoke(
           classOf[UTF8String],
@@ -713,6 +731,10 @@ object ScalaReflection extends ScalaReflection {
         val Schema(valueDataType, valueNullable) = schemaFor(valueType)
         Schema(MapType(schemaFor(keyType).dataType,
           valueDataType, valueContainsNull = valueNullable), nullable = true)
+      case t if t <:< localTypeOf[Set[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+        val Schema(dataType, nullable) = schemaFor(elementType)
+        Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
       case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
       case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
       case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)

http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 24c06d8..9b28a18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -627,8 +627,9 @@ case class MapObjects private(
 
     val (initCollection, addElement, getResult): (String, String => String, String) =
       customCollectionCls match {
-        case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
-          // Scala sequence
+        case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) ||
+          classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
+          // Scala sequence or set
           val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()"
           val builder = ctx.freshName("collectionBuilder")
           (

http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 86574e2..05db292 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -171,6 +171,16 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
   /** @since 2.3.0 */
   implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder()
 
+  /**
+   * Notice that we serialize `Set` to Catalyst array. The set property is only kept when
+   * manipulating the domain objects. The serialization format doesn't keep the set property.
+   * When we have a Catalyst array which contains duplicated elements and convert it to
+   * `Dataset[Set[T]]` by using the encoder, the elements will be de-duplicated.
+   *
+   * @since 2.3.0
+   */
+  implicit def newSetEncoder[T <: Set[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
+
   // Arrays
 
   /** @since 1.6.1 */

http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 5db354d..b52d50b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -460,6 +460,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       df.select(collect_set($"a"), collect_set($"b")),
       Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
     )
+
+    checkDataset(
+      df.select(collect_set($"a").as("aSet")).as[Set[Int]],
+      Set(1, 2, 3))
+    checkDataset(
+      df.select(collect_set($"b").as("bSet")).as[Set[Int]],
+      Set(2, 4))
+    checkDataset(
+      df.select(collect_set($"a"), collect_set($"b")).as[(Set[Int], Set[Int])],
+      Seq(Set(1, 2, 3) -> Set(2, 4)): _*)
   }
 
   test("collect functions structs") {

http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index a6847dc..f62f9e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import scala.collection.immutable.{HashSet => HSet}
 import scala.collection.immutable.Queue
 import scala.collection.mutable.{LinkedHashMap => LHMap}
 import scala.collection.mutable.ArrayBuffer
@@ -342,6 +343,31 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
       LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4))))
   }
 
+  test("arbitrary sets") {
+    checkDataset(Seq(Set(1, 2, 3, 4)).toDS(), Set(1, 2, 3, 4))
+    checkDataset(Seq(Set(1.toLong, 2.toLong)).toDS(), Set(1.toLong, 2.toLong))
+    checkDataset(Seq(Set(1.toDouble, 2.toDouble)).toDS(), Set(1.toDouble, 2.toDouble))
+    checkDataset(Seq(Set(1.toFloat, 2.toFloat)).toDS(), Set(1.toFloat, 2.toFloat))
+    checkDataset(Seq(Set(1.toByte, 2.toByte)).toDS(), Set(1.toByte, 2.toByte))
+    checkDataset(Seq(Set(1.toShort, 2.toShort)).toDS(), Set(1.toShort, 2.toShort))
+    checkDataset(Seq(Set(true, false)).toDS(), Set(true, false))
+    checkDataset(Seq(Set("test1", "test2")).toDS(), Set("test1", "test2"))
+    checkDataset(Seq(Set(Tuple1(1), Tuple1(2))).toDS(), Set(Tuple1(1), Tuple1(2)))
+
+    checkDataset(Seq(HSet(1, 2)).toDS(), HSet(1, 2))
+    checkDataset(Seq(HSet(1.toLong, 2.toLong)).toDS(), HSet(1.toLong, 2.toLong))
+    checkDataset(Seq(HSet(1.toDouble, 2.toDouble)).toDS(), HSet(1.toDouble, 2.toDouble))
+    checkDataset(Seq(HSet(1.toFloat, 2.toFloat)).toDS(), HSet(1.toFloat, 2.toFloat))
+    checkDataset(Seq(HSet(1.toByte, 2.toByte)).toDS(), HSet(1.toByte, 2.toByte))
+    checkDataset(Seq(HSet(1.toShort, 2.toShort)).toDS(), HSet(1.toShort, 2.toShort))
+    checkDataset(Seq(HSet(true, false)).toDS(), HSet(true, false))
+    checkDataset(Seq(HSet("test1", "test2")).toDS(), HSet("test1", "test2"))
+    checkDataset(Seq(HSet(Tuple1(1), Tuple1(2))).toDS(), HSet(Tuple1(1), Tuple1(2)))
+
+    checkDataset(Seq(Seq(Some(1), None), Seq(Some(2))).toDF("c").as[Set[Integer]],
+      Seq(Set[Integer](1, null), Set[Integer](2)): _*)
+  }
+
   test("nested sequences") {
     checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1)))
     checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1)))
@@ -352,6 +378,11 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3))
   }
 
+  test("nested set") {
+    checkDataset(Seq(Set(HSet(1, 2), HSet(3, 4))).toDS(), Set(HSet(1, 2), HSet(3, 4)))
+    checkDataset(Seq(HSet(Set(1, 2), Set(3, 4))).toDS(), HSet(Set(1, 2), Set(3, 4)))
+  }
+
   test("package objects") {
     import packageobject._
     checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org