You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2016/02/02 19:15:45 UTC

spark git commit: [SPARK-13094][SQL] Add encoders for seq/array of primitives

Repository: spark
Updated Branches:
  refs/heads/master 12a20c144 -> 29d92181d


[SPARK-13094][SQL] Add encoders for seq/array of primitives

Author: Michael Armbrust <mi...@databricks.com>

Closes #11014 from marmbrus/seqEncoders.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/29d92181
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/29d92181
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/29d92181

Branch: refs/heads/master
Commit: 29d92181d0c49988c387d34e4a71b1afe02c29e2
Parents: 12a20c1
Author: Michael Armbrust <mi...@databricks.com>
Authored: Tue Feb 2 10:15:40 2016 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Feb 2 10:15:40 2016 -0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/SQLImplicits.scala     | 63 +++++++++++++++++++-
 .../spark/sql/DatasetPrimitiveSuite.scala       | 22 +++++++
 .../scala/org/apache/spark/sql/QueryTest.scala  |  8 ++-
 3 files changed, 91 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/29d92181/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index ab41479..16c4095 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -39,6 +39,8 @@ abstract class SQLImplicits {
   /** @since 1.6.0 */
   implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder()
 
+  // Primitives
+
   /** @since 1.6.0 */
   implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder()
 
@@ -56,13 +58,72 @@ abstract class SQLImplicits {
 
   /** @since 1.6.0 */
   implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder()
-  /** @since 1.6.0 */
 
+  /** @since 1.6.0 */
   implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder()
 
   /** @since 1.6.0 */
   implicit def newStringEncoder: Encoder[String] = ExpressionEncoder()
 
+  // Seqs of primitives, String and Product. NOTE(review): confirm List[T] also resolves
+
+  /** @since 1.6.1 */
+  implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
+
+  // Arrays of primitives, String and Product types (Array[T] encoders)
+
+  /** @since 1.6.1 */
+  implicit def newIntArrayEncoder: Encoder[Array[Int]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newLongArrayEncoder: Encoder[Array[Long]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newDoubleArrayEncoder: Encoder[Array[Double]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newFloatArrayEncoder: Encoder[Array[Float]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newByteArrayEncoder: Encoder[Array[Byte]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newShortArrayEncoder: Encoder[Array[Short]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newBooleanArrayEncoder: Encoder[Array[Boolean]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newStringArrayEncoder: Encoder[Array[String]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newProductArrayEncoder[A <: Product : TypeTag]: Encoder[Array[A]] =
+    ExpressionEncoder()
+
   /**
    * Creates a [[Dataset]] from an RDD.
    * @since 1.6.0

http://git-wip-us.apache.org/repos/asf/spark/blob/29d92181/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index f75d096..243d13b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -105,4 +105,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
       agged,
       "1", "abc", "3", "xyz", "5", "hello")
   }
+
+  test("Arrays and Lists") { // single-element Seq[T] and Array[T] round-trips via toDS()
+    checkAnswer(Seq(Seq(1)).toDS(), Seq(1))
+    checkAnswer(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong))
+    checkAnswer(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble))
+    checkAnswer(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat))
+    checkAnswer(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte))
+    checkAnswer(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort))
+    checkAnswer(Seq(Seq(true)).toDS(), Seq(true))
+    checkAnswer(Seq(Seq("test")).toDS(), Seq("test"))
+    checkAnswer(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1)))
+    // Arrays compare by reference, so checkAnswer compares these array results as Seqs
+    checkAnswer(Seq(Array(1)).toDS(), Array(1))
+    checkAnswer(Seq(Array(1.toLong)).toDS(), Array(1.toLong))
+    checkAnswer(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble))
+    checkAnswer(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat))
+    checkAnswer(Seq(Array(1.toByte)).toDS(), Array(1.toByte))
+    checkAnswer(Seq(Array(1.toShort)).toDS(), Array(1.toShort))
+    checkAnswer(Seq(Array(true)).toDS(), Array(true))
+    checkAnswer(Seq(Array("test")).toDS(), Array("test"))
+    checkAnswer(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/29d92181/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 405e589..5401212 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -95,7 +95,13 @@ abstract class QueryTest extends PlanTest {
            """.stripMargin, e)
     }
 
-    if (decoded != expectedAnswer.toSet) {
+    // Arrays use reference equality, so compare array results (and expected values) as Seqs
+    val isArray = decoded.headOption.exists(_.getClass.isArray)
+    def normalEquality = decoded == expectedAnswer.toSet
+    def asSeq(a: Any): Any = a match { case arr: Array[_] => arr.toSeq; case other => other }
+    def expectedAsSeq = expectedAnswer.map(asSeq).toSet
+    def decodedAsSeq = decoded.map(asSeq)
+    if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) {
       val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted
       val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org