Posted to commits@spark.apache.org by we...@apache.org on 2017/07/06 17:07:55 UTC
spark git commit: [SPARK-21204][SQL] Add support for Scala Set collection types in serialization
Repository: spark
Updated Branches:
refs/heads/master 26ac085de -> 48e44b24a
[SPARK-21204][SQL] Add support for Scala Set collection types in serialization
## What changes were proposed in this pull request?
Currently we can't produce a `Dataset` containing a `Set` in Spark SQL. This PR adds support for serializing and deserializing `Set`s.
Because Spark SQL has no internal data type corresponding to a `Set`, an array is the most natural choice for serializing one.
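As a usage sketch (assuming an existing `SparkSession` named `spark`; this snippet is illustrative, not part of the patch):

```scala
import spark.implicits._

// A `Set` is encoded as a Catalyst array, so the derived schema is array<int>.
val ds = Seq(Set(1, 2, 3)).toDS()
ds.printSchema()
// root
//  |-- value: array (nullable = true)
//  |    |-- element: integer (containsNull = false)

// Collecting deserializes each Catalyst array back into a Set.
val sets: Array[Set[Int]] = ds.collect()
```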
## How was this patch tested?
Added unit tests.
Author: Liang-Chi Hsieh <vi...@gmail.com>
Closes #18416 from viirya/SPARK-21204.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/48e44b24
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/48e44b24
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/48e44b24
Branch: refs/heads/master
Commit: 48e44b24a7663142176102ac4c6bf4242f103804
Parents: 26ac085
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Fri Jul 7 01:07:45 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Jul 7 01:07:45 2017 +0800
----------------------------------------------------------------------
.../spark/sql/catalyst/ScalaReflection.scala | 28 ++++++++++++++++--
.../catalyst/expressions/objects/objects.scala | 5 ++--
.../org/apache/spark/sql/SQLImplicits.scala | 10 +++++++
.../spark/sql/DataFrameAggregateSuite.scala | 10 +++++++
.../spark/sql/DatasetPrimitiveSuite.scala | 31 ++++++++++++++++++++
5 files changed, 79 insertions(+), 5 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 814f2c1..4d5401f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -309,7 +309,10 @@ object ScalaReflection extends ScalaReflection {
Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false)
}
- case t if t <:< localTypeOf[Seq[_]] =>
+ // We serialize a `Set` to a Catalyst array. When we deserialize a Catalyst array
+ // to a `Set`, duplicate elements are de-duplicated.
+ case t if t <:< localTypeOf[Seq[_]] ||
+ t <:< localTypeOf[scala.collection.Set[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, elementNullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
@@ -327,8 +330,10 @@ object ScalaReflection extends ScalaReflection {
}
val companion = t.normalize.typeSymbol.companionSymbol.typeSignature
- val cls = companion.declaration(newTermName("newBuilder")) match {
- case NoSymbol => classOf[Seq[_]]
+ val cls = companion.member(newTermName("newBuilder")) match {
+ case NoSymbol if t <:< localTypeOf[Seq[_]] => classOf[Seq[_]]
+ case NoSymbol if t <:< localTypeOf[scala.collection.Set[_]] =>
+ classOf[scala.collection.Set[_]]
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
}
UnresolvedMapObjects(mapFunction, getPath, Some(cls))
@@ -502,6 +507,19 @@ object ScalaReflection extends ScalaReflection {
serializerFor(_, valueType, valuePath, seenTypeSet),
valueNullable = !valueType.typeSymbol.asClass.isPrimitive)
+ case t if t <:< localTypeOf[scala.collection.Set[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+
+ // There's no corresponding Catalyst type for `Set`, so we serialize a `Set` to a Catalyst array.
+ // Note that the set property is only kept when manipulating the data as a domain object.
+ val newInput =
+ Invoke(
+ inputObject,
+ "toSeq",
+ ObjectType(classOf[Seq[_]]))
+
+ toCatalystArray(newInput, elementType)
+
case t if t <:< localTypeOf[String] =>
StaticInvoke(
classOf[UTF8String],
@@ -713,6 +731,10 @@ object ScalaReflection extends ScalaReflection {
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
+ case t if t <:< localTypeOf[Set[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ val Schema(dataType, nullable) = schemaFor(elementType)
+ Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
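For reference, the new `schemaFor` case above derives an array schema for set types, mirroring the existing `Seq` handling. A minimal sketch against a patched build (illustrative only; `ScalaReflection` is a Catalyst-internal API):

```scala
import org.apache.spark.sql.catalyst.ScalaReflection

// Set[Int] maps to array<int> with non-nullable elements, just like Seq[Int].
val schema = ScalaReflection.schemaFor[Set[Int]]
// => Schema(ArrayType(IntegerType, containsNull = false), nullable = true)
```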
http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 24c06d8..9b28a18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -627,8 +627,9 @@ case class MapObjects private(
val (initCollection, addElement, getResult): (String, String => String, String) =
customCollectionCls match {
- case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
- // Scala sequence
+ case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) ||
+ classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
+ // Scala sequence or set
val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()"
val builder = ctx.freshName("collectionBuilder")
(
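The codegen above builds the result collection through the companion's `newBuilder`, which works identically for sequences and sets. In plain Scala, the emitted pattern amounts to roughly:

```scala
// What `<cls>$.MODULE$.newBuilder()` does for a Set, written out by hand:
val builder = scala.collection.immutable.Set.newBuilder[Int]
builder += 1
builder += 2
builder += 2                              // duplicates collapse on insertion
val result: Set[Int] = builder.result()   // Set(1, 2)
```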
http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 86574e2..05db292 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -171,6 +171,16 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
/** @since 2.3.0 */
implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder()
+ /**
+ * Notice that we serialize a `Set` to a Catalyst array. The set property is only kept
+ * when manipulating the domain objects; the serialization format does not preserve it.
+ * When a Catalyst array containing duplicate elements is converted to a `Dataset[Set[T]]`
+ * with this encoder, the elements are de-duplicated.
+ *
+ * @since 2.3.0
+ */
+ implicit def newSetEncoder[T <: Set[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
+
// Arrays
/** @since 1.6.1 */
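With `newSetEncoder` in implicit scope, `Set`-typed Datasets can be created and read back directly. A brief usage sketch (assuming a `SparkSession` named `spark`; illustrative only):

```scala
import org.apache.spark.sql.Dataset
import spark.implicits._

val ds: Dataset[Set[Int]] = Seq(Set(1, 2), Set(3)).toDS()

// Reading an array column as a Set de-duplicates its elements.
val deduped = Seq(Seq(1, 1, 2)).toDF("c").as[Set[Int]]
deduped.collect()  // Array(Set(1, 2))
```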
http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 5db354d..b52d50b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -460,6 +460,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
df.select(collect_set($"a"), collect_set($"b")),
Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
)
+
+ checkDataset(
+ df.select(collect_set($"a").as("aSet")).as[Set[Int]],
+ Set(1, 2, 3))
+ checkDataset(
+ df.select(collect_set($"b").as("bSet")).as[Set[Int]],
+ Set(2, 4))
+ checkDataset(
+ df.select(collect_set($"a"), collect_set($"b")).as[(Set[Int], Set[Int])],
+ Seq(Set(1, 2, 3) -> Set(2, 4)): _*)
}
test("collect functions structs") {
http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index a6847dc..f62f9e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import scala.collection.immutable.{HashSet => HSet}
import scala.collection.immutable.Queue
import scala.collection.mutable.{LinkedHashMap => LHMap}
import scala.collection.mutable.ArrayBuffer
@@ -342,6 +343,31 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4))))
}
+ test("arbitrary sets") {
+ checkDataset(Seq(Set(1, 2, 3, 4)).toDS(), Set(1, 2, 3, 4))
+ checkDataset(Seq(Set(1.toLong, 2.toLong)).toDS(), Set(1.toLong, 2.toLong))
+ checkDataset(Seq(Set(1.toDouble, 2.toDouble)).toDS(), Set(1.toDouble, 2.toDouble))
+ checkDataset(Seq(Set(1.toFloat, 2.toFloat)).toDS(), Set(1.toFloat, 2.toFloat))
+ checkDataset(Seq(Set(1.toByte, 2.toByte)).toDS(), Set(1.toByte, 2.toByte))
+ checkDataset(Seq(Set(1.toShort, 2.toShort)).toDS(), Set(1.toShort, 2.toShort))
+ checkDataset(Seq(Set(true, false)).toDS(), Set(true, false))
+ checkDataset(Seq(Set("test1", "test2")).toDS(), Set("test1", "test2"))
+ checkDataset(Seq(Set(Tuple1(1), Tuple1(2))).toDS(), Set(Tuple1(1), Tuple1(2)))
+
+ checkDataset(Seq(HSet(1, 2)).toDS(), HSet(1, 2))
+ checkDataset(Seq(HSet(1.toLong, 2.toLong)).toDS(), HSet(1.toLong, 2.toLong))
+ checkDataset(Seq(HSet(1.toDouble, 2.toDouble)).toDS(), HSet(1.toDouble, 2.toDouble))
+ checkDataset(Seq(HSet(1.toFloat, 2.toFloat)).toDS(), HSet(1.toFloat, 2.toFloat))
+ checkDataset(Seq(HSet(1.toByte, 2.toByte)).toDS(), HSet(1.toByte, 2.toByte))
+ checkDataset(Seq(HSet(1.toShort, 2.toShort)).toDS(), HSet(1.toShort, 2.toShort))
+ checkDataset(Seq(HSet(true, false)).toDS(), HSet(true, false))
+ checkDataset(Seq(HSet("test1", "test2")).toDS(), HSet("test1", "test2"))
+ checkDataset(Seq(HSet(Tuple1(1), Tuple1(2))).toDS(), HSet(Tuple1(1), Tuple1(2)))
+
+ checkDataset(Seq(Seq(Some(1), None), Seq(Some(2))).toDF("c").as[Set[Integer]],
+ Seq(Set[Integer](1, null), Set[Integer](2)): _*)
+ }
+
test("nested sequences") {
checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1)))
checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1)))
@@ -352,6 +378,11 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3))
}
+ test("nested set") {
+ checkDataset(Seq(Set(HSet(1, 2), HSet(3, 4))).toDS(), Set(HSet(1, 2), HSet(3, 4)))
+ checkDataset(Seq(HSet(Set(1, 2), Set(3, 4))).toDS(), HSet(Set(1, 2), Set(3, 4)))
+ }
+
test("package objects") {
import packageobject._
checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))