Posted to commits@spark.apache.org by ma...@apache.org on 2015/01/21 00:16:27 UTC
[3/3] spark git commit: [SPARK-5323][SQL] Remove Row's Seq inheritance.
[SPARK-5323][SQL] Remove Row's Seq inheritance.
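Row no longer extends Seq[Any]; code that relied on collection methods now converts
explicitly via toSeq, while a few conveniences (length, mkString, equals/hashCode)
are reimplemented directly on the trait, as the Row.scala diff below shows. A minimal
before/after sketch, assuming a Row built via the companion object:

    import org.apache.spark.sql.Row

    val row = Row(1, "a", null)

    // Before: Row was a Seq, so collection methods applied directly,
    // e.g. row.zip(dataTypes).
    // After: convert explicitly where Seq semantics are wanted.
    row.toSeq.zip(Seq("first", "second", "third"))
    row.mkString("[", ",", "]")   // "[1,a,null]" -- mkString is kept on Row
    row.length                    // 3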
Author: Reynold Xin <rx...@databricks.com>
Closes #4115 from rxin/row-seq and squashes the following commits:
e33abd8 [Reynold Xin] Fixed compilation error.
cceb650 [Reynold Xin] Python test fixes, and removal of WrapDynamic.
0334a52 [Reynold Xin] mkString.
9cdeb7d [Reynold Xin] Hive tests.
15681c2 [Reynold Xin] Fix more test cases.
ea9023a [Reynold Xin] Fixed a catalyst test.
c5e2cb5 [Reynold Xin] Minor patch up.
b9cab7c [Reynold Xin] [SPARK-5323][SQL] Remove Row's Seq inheritance.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d181c2a1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d181c2a1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d181c2a1
Branch: refs/heads/master
Commit: d181c2a1fc40746947b97799b12e7dd8c213fa9c
Parents: bc20a52
Author: Reynold Xin <rx...@databricks.com>
Authored: Tue Jan 20 15:16:14 2015 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Jan 20 15:16:14 2015 -0800
----------------------------------------------------------------------
.../main/scala/org/apache/spark/sql/Row.scala | 75 +++-
.../spark/sql/catalyst/ScalaReflection.scala | 3 +-
.../apache/spark/sql/catalyst/dsl/package.scala | 3 -
.../spark/sql/catalyst/expressions/Cast.scala | 3 +-
.../sql/catalyst/expressions/Projection.scala | 310 ++++++++------
.../expressions/SpecificMutableRow.scala | 4 +-
.../sql/catalyst/expressions/WrapDynamic.scala | 64 ---
.../codegen/GenerateProjection.scala | 22 +-
.../spark/sql/catalyst/expressions/rows.scala | 6 +-
.../sql/catalyst/ScalaReflectionSuite.scala | 2 +-
.../scala/org/apache/spark/sql/SchemaRDD.scala | 19 -
.../columnar/InMemoryColumnarTableScan.scala | 9 +-
.../compression/compressionSchemes.scala | 2 +-
.../spark/sql/execution/debug/package.scala | 2 +-
.../apache/spark/sql/execution/pythonUdfs.scala | 7 +-
.../org/apache/spark/sql/json/JsonRDD.scala | 8 +-
.../spark/sql/parquet/ParquetConverter.scala | 2 +-
.../org/apache/spark/sql/DslQuerySuite.scala | 146 +++----
.../scala/org/apache/spark/sql/JoinSuite.scala | 242 +++++------
.../scala/org/apache/spark/sql/QueryTest.scala | 31 +-
.../org/apache/spark/sql/SQLQuerySuite.scala | 416 ++++++++++---------
.../sql/ScalaReflectionRelationSuite.scala | 6 +-
.../spark/sql/columnar/ColumnStatsSuite.scala | 6 +-
.../columnar/InMemoryColumnarQuerySuite.scala | 18 +-
.../columnar/PartitionBatchPruningSuite.scala | 2 +-
.../compression/BooleanBitSetSuite.scala | 2 +-
.../apache/spark/sql/execution/TgfSuite.scala | 4 +-
.../org/apache/spark/sql/json/JsonSuite.scala | 185 +++++----
.../spark/sql/parquet/ParquetFilterSuite.scala | 91 ++--
.../spark/sql/parquet/ParquetIOSuite.scala | 12 +-
.../spark/sql/parquet/ParquetQuerySuite.scala | 28 +-
.../spark/sql/parquet/ParquetQuerySuite2.scala | 2 +-
.../spark/sql/sources/TableScanSuite.scala | 4 +-
.../org/apache/spark/sql/hive/HiveContext.scala | 8 +-
.../apache/spark/sql/hive/HiveInspectors.scala | 20 +-
.../org/apache/spark/sql/hive/hiveUdfs.scala | 2 +-
.../spark/sql/hive/hiveWriterContainers.scala | 2 +-
.../scala/org/apache/spark/sql/QueryTest.scala | 48 ++-
.../spark/sql/hive/HiveInspectorSuite.scala | 8 +-
.../sql/hive/InsertIntoHiveTableSuite.scala | 6 +-
.../sql/hive/MetastoreDataSourcesSuite.scala | 12 +-
.../apache/spark/sql/hive/StatisticsSuite.scala | 8 +-
.../sql/hive/execution/HiveQuerySuite.scala | 34 +-
.../spark/sql/hive/execution/HiveUdfSuite.scala | 12 +-
.../sql/hive/execution/SQLQuerySuite.scala | 18 +-
.../spark/sql/parquet/HiveParquetSuite.scala | 2 +-
.../spark/sql/parquet/parquetSuites.scala | 58 +--
47 files changed, 1018 insertions(+), 956 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 208ec92..41bb4f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import scala.util.hashing.MurmurHash3
+
import org.apache.spark.sql.catalyst.expressions.GenericRow
@@ -32,7 +34,7 @@ object Row {
* }
* }}}
*/
- def unapplySeq(row: Row): Some[Seq[Any]] = Some(row)
+ def unapplySeq(row: Row): Some[Seq[Any]] = Some(row.toSeq)
/**
* This method can be used to construct a [[Row]] with the given values.
@@ -43,6 +45,16 @@ object Row {
* This method can be used to construct a [[Row]] from a [[Seq]] of values.
*/
def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray)
+
+ def fromTuple(tuple: Product): Row = fromSeq(tuple.productIterator.toSeq)
+
+ /**
+ * Merge multiple rows into a single row, one after another.
+ */
+ def merge(rows: Row*): Row = {
+ // TODO: Improve the performance of this if used in a performance-critical code path.
+ new GenericRow(rows.flatMap(_.toSeq).toArray)
+ }
}
@@ -103,7 +115,13 @@ object Row {
*
* @group row
*/
-trait Row extends Seq[Any] with Serializable {
+trait Row extends Serializable {
+ /** Number of elements in the Row. */
+ def size: Int = length
+
+ /** Number of elements in the Row. */
+ def length: Int
+
/**
* Returns the value at position i. If the value is null, null is returned. The following
* is a mapping between Spark SQL types and return types:
@@ -291,12 +309,61 @@ trait Row extends Seq[Any] with Serializable {
/** Returns true if there are any NULL values in this row. */
def anyNull: Boolean = {
- val l = length
+ val len = length
var i = 0
- while (i < l) {
+ while (i < len) {
if (isNullAt(i)) { return true }
i += 1
}
false
}
+
+ override def equals(that: Any): Boolean = that match {
+ case null => false
+ case that: Row =>
+ if (this.length != that.length) {
+ return false
+ }
+ var i = 0
+ val len = this.length
+ while (i < len) {
+ if (apply(i) != that.apply(i)) {
+ return false
+ }
+ i += 1
+ }
+ true
+ case _ => false
+ }
+
+ override def hashCode: Int = {
+ // Using Scala's Seq hash code implementation.
+ var n = 0
+ var h = MurmurHash3.seqSeed
+ val len = length
+ while (n < len) {
+ h = MurmurHash3.mix(h, apply(n).##)
+ n += 1
+ }
+ MurmurHash3.finalizeHash(h, n)
+ }
+
+ /* ---------------------- utility methods for Scala ---------------------- */
+
+ /**
+ * Returns a Scala Seq representing the row. Elements are placed in the same order in the Seq.
+ */
+ def toSeq: Seq[Any]
+
+ /** Displays all elements of this sequence in a string (without a separator). */
+ def mkString: String = toSeq.mkString
+
+ /** Displays all elements of this sequence in a string using a separator string. */
+ def mkString(sep: String): String = toSeq.mkString(sep)
+
+ /**
+ * Displays all elements of this sequence in a string using
+ * start, end, and separator strings.
+ */
+ def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end)
}
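The companion-object helpers and Seq-compatible hashing added above, in use (a small
sketch; the hashCode behavior is inferred from the MurmurHash3 code shown, which
mirrors Scala's Seq hashing):

    import org.apache.spark.sql.Row

    val a = Row.fromTuple((1, "x"))        // same as Row(1, "x")
    val b = Row.fromSeq(Seq(true, 2.5))
    Row.merge(a, b)                        // Row(1, "x", true, 2.5)

    // equals/hashCode are now defined positionally on the trait itself;
    // hashCode reuses Scala's Seq hashing, so rows and seqs still agree:
    Row(1, "x").hashCode == Seq(1, "x").hashCode   // true

    // Pattern matching keeps working through the fixed unapplySeq:
    Row(1, "x") match { case Row(i: Int, s: String) => s * i }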
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d280db8..191d16f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -84,8 +84,9 @@ trait ScalaReflection {
}
def convertRowToScala(r: Row, schema: StructType): Row = {
+ // TODO: This is very slow!!!
new GenericRow(
- r.zip(schema.fields.map(_.dataType))
+ r.toSeq.zip(schema.fields.map(_.dataType))
.map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 26c8558..417659e 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -272,9 +272,6 @@ package object dsl {
def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) =
Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)
- def sfilter(dynamicUdf: (DynamicRow) => Boolean) =
- Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan)
-
def sample(
fraction: Double,
withReplacement: Boolean = true,
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 1a2133b..ece5ee7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -407,7 +407,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
val casts = from.fields.zip(to.fields).map {
case (fromField, toField) => cast(fromField.dataType, toField.dataType)
}
- buildCast[Row](_, row => Row(row.zip(casts).map {
+ // TODO: This is very slow!
+ buildCast[Row](_, row => Row(row.toSeq.zip(casts).map {
case (v, cast) => if (v == null) null else cast(v)
}: _*))
}
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index e7e81a2..db5d897 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -105,45 +105,45 @@ class JoinedRow extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -154,8 +154,16 @@ class JoinedRow extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -197,45 +205,45 @@ class JoinedRow2 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -246,8 +254,16 @@ class JoinedRow2 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -283,45 +299,45 @@ class JoinedRow3 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -332,8 +348,16 @@ class JoinedRow3 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -369,45 +393,45 @@ class JoinedRow4 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -418,8 +442,16 @@ class JoinedRow4 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -455,45 +487,45 @@ class JoinedRow5 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -504,7 +536,15 @@ class JoinedRow5 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
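The JoinedRow* rewrite above in action (a sketch; the two-argument constructor is
assumed from the class's existing API and is not part of this diff):

    import org.apache.spark.sql.catalyst.expressions.{GenericRow, JoinedRow}

    val left   = new GenericRow(Array[Any](1, "a"))
    val right  = new GenericRow(Array[Any](2.0, null))
    val joined = new JoinedRow(left, right)

    joined.length     // 4
    joined.toSeq      // concatenation of both sides: 1, "a", 2.0, null
    joined.toString   // "[1,a,2.0,null]"; with either side still null it
                      // prints the other side (or "[ empty row ]") instead
                      // of throwing a NullPointerException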
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 37d9f0e..7434165 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -209,6 +209,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
override def length: Int = values.length
+ override def toSeq: Seq[Any] = values.map(_.boxed).toSeq
+
override def setNullAt(i: Int): Unit = {
values(i).isNull = true
}
@@ -231,8 +233,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
if (value == null) setNullAt(ordinal) else values(ordinal).update(value)
}
- override def iterator: Iterator[Any] = values.map(_.boxed).iterator
-
override def setString(ordinal: Int, value: String) = update(ordinal, value)
override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
deleted file mode 100644
index e2f5c73..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import scala.language.dynamics
-
-import org.apache.spark.sql.types.DataType
-
-/**
- * The data type representing [[DynamicRow]] values.
- */
-case object DynamicType extends DataType {
-
- /**
- * The default size of a value of the DynamicType is 4096 bytes.
- */
- override def defaultSize: Int = 4096
-}
-
-/**
- * Wrap a [[Row]] as a [[DynamicRow]].
- */
-case class WrapDynamic(children: Seq[Attribute]) extends Expression {
- type EvaluatedType = DynamicRow
-
- def nullable = false
-
- def dataType = DynamicType
-
- override def eval(input: Row): DynamicRow = input match {
- // Avoid copy for generic rows.
- case g: GenericRow => new DynamicRow(children, g.values)
- case otherRowType => new DynamicRow(children, otherRowType.toArray)
- }
-}
-
-/**
- * DynamicRows use scala's Dynamic trait to emulate an ORM of in a dynamically typed language.
- * Since the type of the column is not known at compile time, all attributes are converted to
- * strings before being passed to the function.
- */
-class DynamicRow(val schema: Seq[Attribute], values: Array[Any])
- extends GenericRow(values) with Dynamic {
-
- def selectDynamic(attributeName: String): String = {
- val ordinal = schema.indexWhere(_.name == attributeName)
- values(ordinal).toString
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index cc97cb4..69397a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -77,14 +77,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
""".children : Seq[Tree]
}
- val iteratorFunction = {
- val allColumns = (0 until expressions.size).map { i =>
- val iLit = ru.Literal(Constant(i))
- q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }"
- }
- q"override def iterator = Iterator[Any](..$allColumns)"
- }
-
val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)"""
val applyFunction = {
val cases = (0 until expressions.size).map { i =>
@@ -191,20 +183,26 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
"""
+ val allColumns = (0 until expressions.size).map { i =>
+ val iLit = ru.Literal(Constant(i))
+ q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }"
+ }
+
val copyFunction =
- q"""
- override def copy() = new $genericRowType(this.toArray)
- """
+ q"override def copy() = new $genericRowType(Array[Any](..$allColumns))"
+
+ val toSeqFunction =
+ q"override def toSeq: Seq[Any] = Seq(..$allColumns)"
val classBody =
nullFunctions ++ (
lengthDef +:
- iteratorFunction +:
applyFunction +:
updateFunction +:
equalsFunction +:
hashCodeFunction +:
copyFunction +:
+ toSeqFunction +:
(tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions))
val code = q"""
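Hand-expanded, the members the generator now emits for a two-column row look roughly
like this (an illustrative sketch, not the literal quasiquote output; c0/c1 stand for
the generated column fields):

    abstract class GeneratedTwoColRow extends org.apache.spark.sql.Row {
      protected var c0: Any = _
      protected var c1: Any = _

      override def copy() = new org.apache.spark.sql.catalyst.expressions.GenericRow(
        Array[Any](if (isNullAt(0)) null else c0, if (isNullAt(1)) null else c1))

      // toSeq replaces the old generated iterator member:
      override def toSeq: Seq[Any] =
        Seq(if (isNullAt(0)) null else c0, if (isNullAt(1)) null else c1)
    }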
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index c22b842..8df150e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -44,7 +44,7 @@ trait MutableRow extends Row {
*/
object EmptyRow extends Row {
override def apply(i: Int): Any = throw new UnsupportedOperationException
- override def iterator = Iterator.empty
+ override def toSeq = Seq.empty
override def length = 0
override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException
override def getInt(i: Int): Int = throw new UnsupportedOperationException
@@ -70,7 +70,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
def this(size: Int) = this(new Array[Any](size))
- override def iterator = values.iterator
+ override def toSeq = values.toSeq
override def length = values.length
@@ -119,7 +119,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
}
// Custom hashCode function that matches the efficient code generated version.
- override def hashCode(): Int = {
+ override def hashCode: Int = {
var result: Int = 37
var i = 0
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 6df5db4..5138942 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -244,7 +244,7 @@ class ScalaReflectionSuite extends FunSuite {
test("convert PrimitiveData to catalyst") {
val data = PrimitiveData(1, 1, 1, 1, 1, 1, true)
- val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
+ val convertedData = Row(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
val dataType = schemaFor[PrimitiveData].dataType
assert(convertToCatalyst(data, dataType) === convertedData)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index ae4d8ba..d1e21df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -332,25 +332,6 @@ class SchemaRDD(
/**
* :: Experimental ::
- * Filters tuples using a function over a `Dynamic` version of a given Row. DynamicRows use
- * scala's Dynamic trait to emulate an ORM of in a dynamically typed language. Since the type of
- * the column is not known at compile time, all attributes are converted to strings before
- * being passed to the function.
- *
- * {{{
- * schemaRDD.where(r => r.firstName == "Bob" && r.lastName == "Smith")
- * }}}
- *
- * @group Query
- */
- @Experimental
- def where(dynamicUdf: (DynamicRow) => Boolean) =
- new SchemaRDD(
- sqlContext,
- Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan))
-
- /**
- * :: Experimental ::
* Returns a sampled version of the underlying dataset.
*
* @group Query
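For callers of the removed dynamic where, the typed Symbol-based overload (assumed
untouched by this patch) expresses the same filter; a sketch, assuming a schemaRDD
with string columns firstName and lastName:

    // Before (removed above):
    //   schemaRDD.where(r => r.firstName == "Bob" && r.lastName == "Smith")
    // After, with compile-time types instead of stringly-typed Dynamic access:
    schemaRDD
      .where('firstName)((name: String) => name == "Bob")
      .where('lastName)((name: String) => name == "Smith")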
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 065fae3..11d5943 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -21,7 +21,6 @@ import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
@@ -128,8 +127,7 @@ private[sql] case class InMemoryRelation(
rowCount += 1
}
- val stats = Row.fromSeq(
- columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _))
+ val stats = Row.merge(columnBuilders.map(_.columnStats.collectedStatistics) : _*)
batchStats += stats
CachedBatch(columnBuilders.map(_.build().array()), stats)
@@ -271,9 +269,10 @@ private[sql] case class InMemoryColumnarTableScan(
// Extract rows via column accessors
new Iterator[Row] {
+ private[this] val rowLen = nextRow.length
override def next() = {
var i = 0
- while (i < nextRow.length) {
+ while (i < rowLen) {
columnAccessors(i).extractTo(nextRow, i)
i += 1
}
@@ -297,7 +296,7 @@ private[sql] case class InMemoryColumnarTableScan(
cachedBatchIterator.filter { cachedBatch =>
if (!partitionFilter(cachedBatch.stats)) {
def statsString = relation.partitionStatistics.schema
- .zip(cachedBatch.stats)
+ .zip(cachedBatch.stats.toSeq)
.map { case (a, s) => s"${a.name}: $s" }
.mkString(", ")
logInfo(s"Skipping partition based on stats $statsString")
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
index 6467324..68a5b1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -127,7 +127,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
while (from.hasRemaining) {
columnType.extract(from, value, 0)
- if (value.head == currentValue.head) {
+ if (value(0) == currentValue(0)) {
currentRun += 1
} else {
// Writes current run
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 46245cd..4d7e338 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -144,7 +144,7 @@ package object debug {
case (null, _) =>
case (row: Row, StructType(fields)) =>
- row.zip(fields.map(_.dataType)).foreach { case(d,t) => typeCheck(d,t) }
+ row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
case (s: Seq[_], ArrayType(elemType, _)) =>
s.foreach(typeCheck(_, elemType))
case (m: Map[_, _], MapType(keyType, valueType, _)) =>
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 7ed64aa..b85021a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -116,9 +116,9 @@ object EvaluatePython {
def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
case (null, _) => null
- case (row: Seq[Any], struct: StructType) =>
+ case (row: Row, struct: StructType) =>
val fields = struct.fields.map(field => field.dataType)
- row.zip(fields).map {
+ row.toSeq.zip(fields).map {
case (obj, dataType) => toJava(obj, dataType)
}.toArray
@@ -143,7 +143,8 @@ object EvaluatePython {
* Convert Row into Java Array (for pickled into Python)
*/
def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = {
- row.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
+ // TODO: this is slow!
+ row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
}
// Converts value to the type specified by the data type.
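The pattern-match changes above guard against a silent behavior change: with Seq
inheritance gone, a `case row: Seq[Any]` arm simply stops matching rows. A small
sketch of the distinction:

    import org.apache.spark.sql.Row

    def describe(v: Any): String = v match {
      case s: Seq[_] => "seq"   // before SPARK-5323 this arm caught rows too
      case r: Row    => "row"   // now rows only reach this arm
      case _         => "other"
    }

    describe(Row(1, 2))   // "row" (previously "seq")
    describe(Seq(1, 2))   // "seq"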
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index db70a7e..9171939 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -458,16 +458,16 @@ private[sql] object JsonRDD extends Logging {
gen.writeEndArray()
case (MapType(kv,vv, _), v: Map[_,_]) =>
- gen.writeStartObject
+ gen.writeStartObject()
v.foreach { p =>
gen.writeFieldName(p._1.toString)
valWriter(vv,p._2)
}
- gen.writeEndObject
+ gen.writeEndObject()
- case (StructType(ty), v: Seq[_]) =>
+ case (StructType(ty), v: Row) =>
gen.writeStartObject()
- ty.zip(v).foreach {
+ ty.zip(v.toSeq).foreach {
case (_, null) =>
case (field, v) =>
gen.writeFieldName(field.name)
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index b4aed04..9d91502 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -66,7 +66,7 @@ private[sql] object CatalystConverter {
// TODO: consider using Array[T] for arrays to avoid boxing of primitive types
type ArrayScalaType[T] = Seq[T]
- type StructScalaType[T] = Seq[T]
+ type StructScalaType[T] = Row
type MapScalaType[K, V] = Map[K, V]
protected[parquet] def createConverter(
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 2bcfe28..afbfe21 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -45,28 +45,28 @@ class DslQuerySuite extends QueryTest {
test("agg") {
checkAnswer(
testData2.groupBy('a)('a, sum('b)),
- Seq((1,3),(2,3),(3,3))
+ Seq(Row(1,3), Row(2,3), Row(3,3))
)
checkAnswer(
testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)),
- 9
+ Row(9)
)
checkAnswer(
testData2.aggregate(sum('b)),
- 9
+ Row(9)
)
}
test("convert $\"attribute name\" into unresolved attribute") {
checkAnswer(
testData.where($"key" === 1).select($"value"),
- Seq(Seq("1")))
+ Row("1"))
}
test("convert Scala Symbol 'attrname into unresolved attribute") {
checkAnswer(
testData.where('key === 1).select('value),
- Seq(Seq("1")))
+ Row("1"))
}
test("select *") {
@@ -78,61 +78,61 @@ class DslQuerySuite extends QueryTest {
test("simple select") {
checkAnswer(
testData.where('key === 1).select('value),
- Seq(Seq("1")))
+ Row("1"))
}
test("select with functions") {
checkAnswer(
testData.select(sum('value), avg('value), count(1)),
- Seq(Seq(5050.0, 50.5, 100)))
+ Row(5050.0, 50.5, 100))
checkAnswer(
testData2.select('a + 'b, 'a < 'b),
Seq(
- Seq(2, false),
- Seq(3, true),
- Seq(3, false),
- Seq(4, false),
- Seq(4, false),
- Seq(5, false)))
+ Row(2, false),
+ Row(3, true),
+ Row(3, false),
+ Row(4, false),
+ Row(4, false),
+ Row(5, false)))
checkAnswer(
testData2.select(sumDistinct('a)),
- Seq(Seq(6)))
+ Row(6))
}
test("global sorting") {
checkAnswer(
testData2.orderBy('a.asc, 'b.asc),
- Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2)))
+ Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
checkAnswer(
testData2.orderBy('a.asc, 'b.desc),
- Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1)))
+ Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1)))
checkAnswer(
testData2.orderBy('a.desc, 'b.desc),
- Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1)))
+ Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1)))
checkAnswer(
testData2.orderBy('a.desc, 'b.asc),
- Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+ Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2)))
checkAnswer(
arrayData.orderBy('data.getItem(0).asc),
- arrayData.collect().sortBy(_.data(0)).toSeq)
+ arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
checkAnswer(
arrayData.orderBy('data.getItem(0).desc),
- arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+ arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
checkAnswer(
- mapData.orderBy('data.getItem(1).asc),
- mapData.collect().sortBy(_.data(1)).toSeq)
+ arrayData.orderBy('data.getItem(1).asc),
+ arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
checkAnswer(
- mapData.orderBy('data.getItem(1).desc),
- mapData.collect().sortBy(_.data(1)).reverse.toSeq)
+ arrayData.orderBy('data.getItem(1).desc),
+ arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
}
test("partition wide sorting") {
@@ -147,19 +147,19 @@ class DslQuerySuite extends QueryTest {
// (3, 2)
checkAnswer(
testData2.sortBy('a.asc, 'b.asc),
- Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2)))
+ Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
checkAnswer(
testData2.sortBy('a.asc, 'b.desc),
- Seq((1,2), (1,1), (2,1), (2,2), (3,2), (3,1)))
+ Seq(Row(1,2), Row(1,1), Row(2,1), Row(2,2), Row(3,2), Row(3,1)))
checkAnswer(
testData2.sortBy('a.desc, 'b.desc),
- Seq((2,1), (1,2), (1,1), (3,2), (3,1), (2,2)))
+ Seq(Row(2,1), Row(1,2), Row(1,1), Row(3,2), Row(3,1), Row(2,2)))
checkAnswer(
testData2.sortBy('a.desc, 'b.asc),
- Seq((2,1), (1,1), (1,2), (3,1), (3,2), (2,2)))
+ Seq(Row(2,1), Row(1,1), Row(1,2), Row(3,1), Row(3,2), Row(2,2)))
}
test("limit") {
@@ -169,11 +169,11 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
arrayData.limit(1),
- arrayData.take(1).toSeq)
+ arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
checkAnswer(
mapData.limit(1),
- mapData.take(1).toSeq)
+ mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
}
test("SPARK-3395 limit distinct") {
@@ -184,8 +184,8 @@ class DslQuerySuite extends QueryTest {
.registerTempTable("onerow")
checkAnswer(
sql("select * from onerow inner join testData2 on onerow.a = testData2.a"),
- (1, 1, 1, 1) ::
- (1, 1, 1, 2) :: Nil)
+ Row(1, 1, 1, 1) ::
+ Row(1, 1, 1, 2) :: Nil)
}
test("SPARK-3858 generator qualifiers are discarded") {
@@ -193,55 +193,55 @@ class DslQuerySuite extends QueryTest {
arrayData.as('ad)
.generate(Explode("data" :: Nil, 'data), alias = Some("ex"))
.select("ex.data".attr),
- Seq(1, 2, 3, 2, 3, 4).map(Seq(_)))
+ Seq(1, 2, 3, 2, 3, 4).map(Row(_)))
}
test("average") {
checkAnswer(
testData2.aggregate(avg('a)),
- 2.0)
+ Row(2.0))
checkAnswer(
testData2.aggregate(avg('a), sumDistinct('a)), // non-partial
- (2.0, 6.0) :: Nil)
+ Row(2.0, 6.0) :: Nil)
checkAnswer(
decimalData.aggregate(avg('a)),
- new java.math.BigDecimal(2.0))
+ Row(new java.math.BigDecimal(2.0)))
checkAnswer(
decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial
- (new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
+ Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
checkAnswer(
decimalData.aggregate(avg('a cast DecimalType(10, 2))),
- new java.math.BigDecimal(2.0))
+ Row(new java.math.BigDecimal(2.0)))
checkAnswer(
decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
- (new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
+ Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
}
test("null average") {
checkAnswer(
testData3.aggregate(avg('b)),
- 2.0)
+ Row(2.0))
checkAnswer(
testData3.aggregate(avg('b), countDistinct('b)),
- (2.0, 1) :: Nil)
+ Row(2.0, 1))
checkAnswer(
testData3.aggregate(avg('b), sumDistinct('b)), // non-partial
- (2.0, 2.0) :: Nil)
+ Row(2.0, 2.0))
}
test("zero average") {
checkAnswer(
emptyTableData.aggregate(avg('a)),
- null)
+ Row(null))
checkAnswer(
emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial
- (null, null) :: Nil)
+ Row(null, null))
}
test("count") {
@@ -249,28 +249,28 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
testData2.aggregate(count('a), sumDistinct('a)), // non-partial
- (6, 6.0) :: Nil)
+ Row(6, 6.0))
}
test("null count") {
checkAnswer(
testData3.groupBy('a)('a, count('b)),
- Seq((1,0), (2, 1))
+ Seq(Row(1,0), Row(2, 1))
)
checkAnswer(
testData3.groupBy('a)('a, count('a + 'b)),
- Seq((1,0), (2, 1))
+ Seq(Row(1,0), Row(2, 1))
)
checkAnswer(
testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
- (2, 1, 2, 2, 1) :: Nil
+ Row(2, 1, 2, 2, 1)
)
checkAnswer(
testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial
- (1, 1, 2) :: Nil
+ Row(1, 1, 2)
)
}
@@ -279,28 +279,28 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial
- (0, null) :: Nil)
+ Row(0, null))
}
test("zero sum") {
checkAnswer(
emptyTableData.aggregate(sum('a)),
- null)
+ Row(null))
}
test("zero sum distinct") {
checkAnswer(
emptyTableData.aggregate(sumDistinct('a)),
- null)
+ Row(null))
}
test("except") {
checkAnswer(
lowerCaseData.except(upperCaseData),
- (1, "a") ::
- (2, "b") ::
- (3, "c") ::
- (4, "d") :: Nil)
+ Row(1, "a") ::
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
checkAnswer(lowerCaseData.except(lowerCaseData), Nil)
checkAnswer(upperCaseData.except(upperCaseData), Nil)
}
@@ -308,10 +308,10 @@ class DslQuerySuite extends QueryTest {
test("intersect") {
checkAnswer(
lowerCaseData.intersect(lowerCaseData),
- (1, "a") ::
- (2, "b") ::
- (3, "c") ::
- (4, "d") :: Nil)
+ Row(1, "a") ::
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
}
@@ -321,75 +321,75 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
// SELECT *, foo(key, value) FROM testData
testData.select(Star(None), foo.call('key, 'value)).limit(3),
- (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil
+ Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil
)
}
test("sqrt") {
checkAnswer(
testData.select(sqrt('key)).orderBy('key asc),
- (1 to 100).map(n => Seq(math.sqrt(n)))
+ (1 to 100).map(n => Row(math.sqrt(n)))
)
checkAnswer(
testData.select(sqrt('value), 'key).orderBy('key asc, 'value asc),
- (1 to 100).map(n => Seq(math.sqrt(n), n))
+ (1 to 100).map(n => Row(math.sqrt(n), n))
)
checkAnswer(
testData.select(sqrt(Literal(null))),
- (1 to 100).map(_ => Seq(null))
+ (1 to 100).map(_ => Row(null))
)
}
test("abs") {
checkAnswer(
testData.select(abs('key)).orderBy('key asc),
- (1 to 100).map(n => Seq(n))
+ (1 to 100).map(n => Row(n))
)
checkAnswer(
negativeData.select(abs('key)).orderBy('key desc),
- (1 to 100).map(n => Seq(n))
+ (1 to 100).map(n => Row(n))
)
checkAnswer(
testData.select(abs(Literal(null))),
- (1 to 100).map(_ => Seq(null))
+ (1 to 100).map(_ => Row(null))
)
}
test("upper") {
checkAnswer(
lowerCaseData.select(upper('l)),
- ('a' to 'd').map(c => Seq(c.toString.toUpperCase()))
+ ('a' to 'd').map(c => Row(c.toString.toUpperCase()))
)
checkAnswer(
testData.select(upper('value), 'key),
- (1 to 100).map(n => Seq(n.toString, n))
+ (1 to 100).map(n => Row(n.toString, n))
)
checkAnswer(
testData.select(upper(Literal(null))),
- (1 to 100).map(n => Seq(null))
+ (1 to 100).map(n => Row(null))
)
}
test("lower") {
checkAnswer(
upperCaseData.select(lower('L)),
- ('A' to 'F').map(c => Seq(c.toString.toLowerCase()))
+ ('A' to 'F').map(c => Row(c.toString.toLowerCase()))
)
checkAnswer(
testData.select(lower('value), 'key),
- (1 to 100).map(n => Seq(n.toString, n))
+ (1 to 100).map(n => Row(n.toString, n))
)
checkAnswer(
testData.select(lower(Literal(null))),
- (1 to 100).map(n => Seq(null))
+ (1 to 100).map(n => Row(null))
)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index e5ab16f..cd36da7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -117,10 +117,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
upperCaseData.join(lowerCaseData, Inner).where('n === 'N),
Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")
+ Row(1, "A", 1, "a"),
+ Row(2, "B", 2, "b"),
+ Row(3, "C", 3, "c"),
+ Row(4, "D", 4, "d")
))
}
@@ -128,10 +128,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)),
Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")
+ Row(1, "A", 1, "a"),
+ Row(2, "B", 2, "b"),
+ Row(3, "C", 3, "c"),
+ Row(4, "D", 4, "d")
))
}
@@ -140,10 +140,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
val y = testData2.where('a === 1).as('y)
checkAnswer(
x.join(y).where("x.a".attr === "y.a".attr),
- (1,1,1,1) ::
- (1,1,1,2) ::
- (1,2,1,1) ::
- (1,2,1,2) :: Nil
+ Row(1,1,1,1) ::
+ Row(1,1,1,2) ::
+ Row(1,2,1,1) ::
+ Row(1,2,1,2) :: Nil
)
}
@@ -163,54 +163,54 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr),
testData.flatMap(
- row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq)
+ row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
test("cartisian product join") {
checkAnswer(
testData3.join(testData3),
- (1, null, 1, null) ::
- (1, null, 2, 2) ::
- (2, 2, 1, null) ::
- (2, 2, 2, 2) :: Nil)
+ Row(1, null, 1, null) ::
+ Row(1, null, 2, 2) ::
+ Row(2, 2, 1, null) ::
+ Row(2, 2, 2, 2) :: Nil)
}
test("left outer join") {
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)),
- (1, "A", 1, "a") ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
+ Row(1, "A", 1, "a") ::
+ Row(2, "B", 2, "b") ::
+ Row(3, "C", 3, "c") ::
+ Row(4, "D", 4, "d") ::
+ Row(5, "E", null, null) ::
+ Row(6, "F", null, null) :: Nil)
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)),
- (1, "A", null, null) ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", 2, "b") ::
+ Row(3, "C", 3, "c") ::
+ Row(4, "D", 4, "d") ::
+ Row(5, "E", null, null) ::
+ Row(6, "F", null, null) :: Nil)
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)),
- (1, "A", null, null) ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", 2, "b") ::
+ Row(3, "C", 3, "c") ::
+ Row(4, "D", 4, "d") ::
+ Row(5, "E", null, null) ::
+ Row(6, "F", null, null) :: Nil)
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)),
- (1, "A", 1, "a") ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
+ Row(1, "A", 1, "a") ::
+ Row(2, "B", 2, "b") ::
+ Row(3, "C", 3, "c") ::
+ Row(4, "D", 4, "d") ::
+ Row(5, "E", null, null) ::
+ Row(6, "F", null, null) :: Nil)
// Make sure we are choosing left.outputPartitioning as the
// outputPartitioning for the outer join operator.
@@ -221,12 +221,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
|GROUP BY l.N
""".stripMargin),
- (1, 1) ::
- (2, 1) ::
- (3, 1) ::
- (4, 1) ::
- (5, 1) ::
- (6, 1) :: Nil)
+ Row(1, 1) ::
+ Row(2, 1) ::
+ Row(3, 1) ::
+ Row(4, 1) ::
+ Row(5, 1) ::
+ Row(6, 1) :: Nil)
checkAnswer(
sql(
@@ -235,42 +235,42 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
|GROUP BY r.a
""".stripMargin),
- (null, 6) :: Nil)
+ Row(null, 6) :: Nil)
}
test("right outer join") {
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)),
- (1, "a", 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "a", 1, "A") ::
+ Row(2, "b", 2, "B") ::
+ Row(3, "c", 3, "C") ::
+ Row(4, "d", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)),
- (null, null, 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(null, null, 1, "A") ::
+ Row(2, "b", 2, "B") ::
+ Row(3, "c", 3, "C") ::
+ Row(4, "d", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)),
- (null, null, 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(null, null, 1, "A") ::
+ Row(2, "b", 2, "B") ::
+ Row(3, "c", 3, "C") ::
+ Row(4, "d", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)),
- (1, "a", 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "a", 1, "A") ::
+ Row(2, "b", 2, "B") ::
+ Row(3, "c", 3, "C") ::
+ Row(4, "d", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
// Make sure we are choosing right.outputPartitioning as the
// outputPartitioning for the outer join operator.
@@ -281,7 +281,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
|GROUP BY l.a
""".stripMargin),
- (null, 6) :: Nil)
+ Row(null, 6))
checkAnswer(
sql(
@@ -290,12 +290,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
|GROUP BY r.N
""".stripMargin),
- (1, 1) ::
- (2, 1) ::
- (3, 1) ::
- (4, 1) ::
- (5, 1) ::
- (6, 1) :: Nil)
+ Row(1, 1) ::
+ Row(2, 1) ::
+ Row(3, 1) ::
+ Row(4, 1) ::
+ Row(5, 1) ::
+ Row(6, 1) :: Nil)
}
test("full outer join") {
@@ -307,32 +307,32 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
- (1, "A", null, null) ::
- (2, "B", null, null) ::
- (3, "C", 3, "C") ::
- (4, "D", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", null, null) ::
+ Row(3, "C", 3, "C") ::
+ Row(4, "D", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))),
- (1, "A", null, null) ::
- (2, "B", null, null) ::
- (3, "C", null, null) ::
- (null, null, 3, "C") ::
- (4, "D", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", null, null) ::
+ Row(3, "C", null, null) ::
+ Row(null, null, 3, "C") ::
+ Row(4, "D", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))),
- (1, "A", null, null) ::
- (2, "B", null, null) ::
- (3, "C", null, null) ::
- (null, null, 3, "C") ::
- (4, "D", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", null, null) ::
+ Row(3, "C", null, null) ::
+ Row(null, null, 3, "C") ::
+ Row(4, "D", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
// Make sure we are using UnknownPartitioning as the outputPartitioning for the outer join operator.
checkAnswer(
@@ -342,7 +342,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
|GROUP BY l.a
""".stripMargin),
- (null, 10) :: Nil)
+ Row(null, 10))
checkAnswer(
sql(
@@ -351,13 +351,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
|GROUP BY r.N
""".stripMargin),
- (1, 1) ::
- (2, 1) ::
- (3, 1) ::
- (4, 1) ::
- (5, 1) ::
- (6, 1) ::
- (null, 4) :: Nil)
+ Row(1, 1) ::
+ Row(2, 1) ::
+ Row(3, 1) ::
+ Row(4, 1) ::
+ Row(5, 1) ::
+ Row(6, 1) ::
+ Row(null, 4) :: Nil)
checkAnswer(
sql(
@@ -366,13 +366,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
|GROUP BY l.N
""".stripMargin),
- (1, 1) ::
- (2, 1) ::
- (3, 1) ::
- (4, 1) ::
- (5, 1) ::
- (6, 1) ::
- (null, 4) :: Nil)
+ Row(1, 1) ::
+ Row(2, 1) ::
+ Row(3, 1) ::
+ Row(4, 1) ::
+ Row(5, 1) ::
+ Row(6, 1) ::
+ Row(null, 4) :: Nil)
checkAnswer(
sql(
@@ -381,7 +381,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
|GROUP BY r.a
""".stripMargin),
- (null, 10) :: Nil)
+ Row(null, 10))
}
test("broadcasted left semi join operator selection") {
@@ -412,12 +412,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("left semi join") {
val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
checkAnswer(rdd,
- (1, 1) ::
- (1, 2) ::
- (2, 1) ::
- (2, 2) ::
- (3, 1) ::
- (3, 2) :: Nil)
+ Row(1, 1) ::
+ Row(1, 2) ::
+ Row(2, 1) ::
+ Row(2, 2) ::
+ Row(3, 1) ::
+ Row(3, 2) :: Nil)
}
}
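One subtlety in the hunks above: a few single-row answers drop the trailing ":: Nil" entirely (the "Row(null, 6))" and "Row(null, 10))" cases), which only type-checks because of the single-Row checkAnswer overload added in QueryTest below. A self-contained sketch of that overload resolution, with hypothetical check methods standing in for checkAnswer:

    import org.apache.spark.sql.Row

    object OverloadSketch {
      def check(expected: Seq[Row]): Unit = println(expected.mkString(", "))
      def check(expected: Row): Unit = check(Seq(expected)) // delegates

      def main(args: Array[String]): Unit = {
        check(Row(null, 6) :: Nil) // resolves to the Seq[Row] variant
        check(Row(null, 6))        // resolves to the single-Row variant
      }
    }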
http://git-wip-us.apache.org/repos/asf/spark/blob/d181c2a1/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 68ddecc..42a21c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -47,26 +47,17 @@ class QueryTest extends PlanTest {
* @param rdd the [[SchemaRDD]] to be executed
* @param expectedAnswer the expected result; can be an Any, a Seq[Product], or a Seq[Seq[Any]].
*/
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = {
- val convertedAnswer = expectedAnswer match {
- case s: Seq[_] if s.isEmpty => s
- case s: Seq[_] if s.head.isInstanceOf[Product] &&
- !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq)
- case s: Seq[_] => s
- case singleItem => Seq(Seq(singleItem))
- }
-
+ protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = {
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
- def prepareAnswer(answer: Seq[Any]): Seq[Any] = {
+ def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
// Java's java.math.BigDecimal.compareTo).
- val converted = answer.map {
- case s: Seq[_] => s.map {
+ val converted: Seq[Row] = answer.map { s =>
+ Row.fromSeq(s.toSeq.map {
case d: java.math.BigDecimal => BigDecimal(d)
case o => o
- }
- case o => o
+ })
}
if (!isSorted) converted.sortBy(_.toString) else converted
}
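The rewritten prepareAnswer keeps both normalizations from the old version, now expressed over Rows: every java.math.BigDecimal field is converted to scala.math.BigDecimal, whose equality follows compareTo semantics and so ignores scale, and rows are sorted by their toString whenever the plan contains no Sort, making the comparison order-insensitive. A small sketch of the per-row conversion, using only Row.toSeq and Row.fromSeq:

    import org.apache.spark.sql.Row

    val raw = Row(new java.math.BigDecimal("1.0"), "x")
    val normalized = Row.fromSeq(raw.toSeq.map {
      case d: java.math.BigDecimal => BigDecimal(d) // scale-insensitive ==
      case o => o
    })
    // After conversion, a field built from "1.0" compares equal to one
    // built from "1.00"; java.math.BigDecimal.equals would not agree,
    // since it also compares scale.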
@@ -82,7 +73,7 @@ class QueryTest extends PlanTest {
""".stripMargin)
}
- if (prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) {
+ if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
fail(s"""
|Results do not match for query:
|${rdd.logicalPlan}
@@ -92,15 +83,19 @@ class QueryTest extends PlanTest {
|${rdd.queryExecution.executedPlan}
|== Results ==
|${sideBySide(
- s"== Correct Answer - ${convertedAnswer.size} ==" +:
- prepareAnswer(convertedAnswer).map(_.toString),
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString),
s"== Spark Answer - ${sparkAnswer.size} ==" +:
prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
""".stripMargin)
}
}
- def sqlTest(sqlString: String, expectedAnswer: Any)(implicit sqlContext: SQLContext): Unit = {
+ protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = {
+ checkAnswer(rdd, Seq(expectedAnswer))
+ }
+
+ def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
test(sqlString) {
checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
}
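Taken together, the new QueryTest surface is: checkAnswer(rdd, Seq[Row]), a convenience checkAnswer(rdd, Row) that wraps the single row, and sqlTest now taking Seq[Row]. A hedged usage sketch, assuming a suite wired up the way the suites in this commit are (extending QueryTest, with the test-only TestSQLContext imports in scope and testData holding (key, value) pairs where value == key.toString):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.test.TestSQLContext._ // provides sql(...)

    class ExampleQuerySuite extends QueryTest {
      test("single-row answers can use the Row overload") {
        checkAnswer(
          sql("SELECT key FROM testData WHERE key = 1"),
          Row(1))
      }

      test("multi-row answers pass a Seq[Row]") {
        checkAnswer(
          sql("SELECT key, value FROM testData WHERE key <= 2"),
          Row(1, "1") :: Row(2, "2") :: Nil)
      }
    }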