You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/31 02:17:30 UTC
spark git commit: [SPARK-9458][SPARK-9469][SQL] Code generate prefix
computation in sorting & moves unsafe conversion out of TungstenSort.
Repository: spark
Updated Branches:
refs/heads/master df3266951 -> e7a0976e9
[SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort.
Author: Reynold Xin <rx...@databricks.com>
Closes #7803 from rxin/SPARK-9458 and squashes the following commits:
5b032dc [Reynold Xin] Fix string.
b670dbb [Reynold Xin] [SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e7a0976e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e7a0976e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e7a0976e
Branch: refs/heads/master
Commit: e7a0976e991f75a7bda99509e2b040daab965ae6
Parents: df32669
Author: Reynold Xin <rx...@databricks.com>
Authored: Thu Jul 30 17:17:27 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Jul 30 17:17:27 2015 -0700
----------------------------------------------------------------------
.../unsafe/sort/PrefixComparators.java | 49 +++++++-----
.../unsafe/sort/PrefixComparatorsSuite.scala | 22 ++----
.../sql/execution/UnsafeExternalRowSorter.java | 27 +++----
.../sql/catalyst/expressions/SortOrder.scala | 44 ++++++++++-
.../spark/sql/execution/SortPrefixUtils.scala | 64 ++-------------
.../spark/sql/execution/SparkStrategies.scala | 4 +-
.../sql/execution/joins/HashedRelation.scala | 4 +-
.../org/apache/spark/sql/execution/sort.scala | 64 +++++++--------
.../execution/RowFormatConvertersSuite.scala | 11 ++-
.../spark/sql/execution/TungstenSortSuite.scala | 83 ++++++++++++++++++++
.../sql/execution/UnsafeExternalSortSuite.scala | 83 --------------------
11 files changed, 216 insertions(+), 239 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index 600aff7..4d7e5b3 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -28,9 +28,11 @@ public class PrefixComparators {
private PrefixComparators() {}
public static final StringPrefixComparator STRING = new StringPrefixComparator();
- public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator();
- public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator();
+ public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc();
+ public static final LongPrefixComparator LONG = new LongPrefixComparator();
+ public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc();
public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
+ public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc();
public static final class StringPrefixComparator extends PrefixComparator {
@Override
@@ -38,50 +40,55 @@ public class PrefixComparators {
return UnsignedLongs.compare(aPrefix, bPrefix);
}
- public long computePrefix(UTF8String value) {
+ public static long computePrefix(UTF8String value) {
return value == null ? 0L : value.getPrefix();
}
}
- /**
- * Prefix comparator for all integral types (boolean, byte, short, int, long).
- */
- public static final class IntegralPrefixComparator extends PrefixComparator {
+ public static final class StringPrefixComparatorDesc extends PrefixComparator {
+ @Override
+ public int compare(long bPrefix, long aPrefix) {
+ return UnsignedLongs.compare(aPrefix, bPrefix);
+ }
+ }
+
+ public static final class LongPrefixComparator extends PrefixComparator {
@Override
public int compare(long a, long b) {
return (a < b) ? -1 : (a > b) ? 1 : 0;
}
+ }
- public final long NULL_PREFIX = Long.MIN_VALUE;
+ public static final class LongPrefixComparatorDesc extends PrefixComparator {
+ @Override
+ public int compare(long b, long a) {
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
+ }
}
- public static final class FloatPrefixComparator extends PrefixComparator {
+ public static final class DoublePrefixComparator extends PrefixComparator {
@Override
public int compare(long aPrefix, long bPrefix) {
- float a = Float.intBitsToFloat((int) aPrefix);
- float b = Float.intBitsToFloat((int) bPrefix);
- return Utils.nanSafeCompareFloats(a, b);
+ double a = Double.longBitsToDouble(aPrefix);
+ double b = Double.longBitsToDouble(bPrefix);
+ return Utils.nanSafeCompareDoubles(a, b);
}
- public long computePrefix(float value) {
- return Float.floatToIntBits(value) & 0xffffffffL;
+ public static long computePrefix(double value) {
+ return Double.doubleToLongBits(value);
}
-
- public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY);
}
- public static final class DoublePrefixComparator extends PrefixComparator {
+ public static final class DoublePrefixComparatorDesc extends PrefixComparator {
@Override
- public int compare(long aPrefix, long bPrefix) {
+ public int compare(long bPrefix, long aPrefix) {
double a = Double.longBitsToDouble(aPrefix);
double b = Double.longBitsToDouble(bPrefix);
return Utils.nanSafeCompareDoubles(a, b);
}
- public long computePrefix(double value) {
+ public static long computePrefix(double value) {
return Double.doubleToLongBits(value);
}
-
- public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY);
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
index cf53a8a..26a2e96 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -29,8 +29,8 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
def testPrefixComparison(s1: String, s2: String): Unit = {
val utf8string1 = UTF8String.fromString(s1)
val utf8string2 = UTF8String.fromString(s2)
- val s1Prefix = PrefixComparators.STRING.computePrefix(utf8string1)
- val s2Prefix = PrefixComparators.STRING.computePrefix(utf8string2)
+ val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1)
+ val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2)
val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix)
val cmp = UnsignedBytes.lexicographicalComparator().compare(
@@ -55,27 +55,15 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
}
- test("float prefix comparator handles NaN properly") {
- val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001)
- val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff)
- assert(nan1.isNaN)
- assert(nan2.isNaN)
- val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1)
- val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2)
- assert(nan1Prefix === nan2Prefix)
- val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue)
- assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1)
- }
-
test("double prefix comparator handles NaNs properly") {
val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
assert(nan1.isNaN)
assert(nan2.isNaN)
- val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1)
- val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2)
+ val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1)
+ val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2)
assert(nan1Prefix === nan2Prefix)
- val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue)
+ val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue)
assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 4c3f2c6..68c49fe 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -48,7 +48,6 @@ final class UnsafeExternalRowSorter {
private long numRowsInserted = 0;
private final StructType schema;
- private final UnsafeProjection unsafeProjection;
private final PrefixComputer prefixComputer;
private final UnsafeExternalSorter sorter;
@@ -62,7 +61,6 @@ final class UnsafeExternalRowSorter {
PrefixComparator prefixComparator,
PrefixComputer prefixComputer) throws IOException {
this.schema = schema;
- this.unsafeProjection = UnsafeProjection.create(schema);
this.prefixComputer = prefixComputer;
final SparkEnv sparkEnv = SparkEnv.get();
final TaskContext taskContext = TaskContext.get();
@@ -88,13 +86,12 @@ final class UnsafeExternalRowSorter {
}
@VisibleForTesting
- void insertRow(InternalRow row) throws IOException {
- UnsafeRow unsafeRow = unsafeProjection.apply(row);
+ void insertRow(UnsafeRow row) throws IOException {
final long prefix = prefixComputer.computePrefix(row);
sorter.insertRecord(
- unsafeRow.getBaseObject(),
- unsafeRow.getBaseOffset(),
- unsafeRow.getSizeInBytes(),
+ row.getBaseObject(),
+ row.getBaseOffset(),
+ row.getSizeInBytes(),
prefix
);
numRowsInserted++;
@@ -113,7 +110,7 @@ final class UnsafeExternalRowSorter {
}
@VisibleForTesting
- Iterator<InternalRow> sort() throws IOException {
+ Iterator<UnsafeRow> sort() throws IOException {
try {
final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
if (!sortedIterator.hasNext()) {
@@ -121,7 +118,7 @@ final class UnsafeExternalRowSorter {
// here in order to prevent memory leaks.
cleanupResources();
}
- return new AbstractScalaRowIterator() {
+ return new AbstractScalaRowIterator<UnsafeRow>() {
private final int numFields = schema.length();
private UnsafeRow row = new UnsafeRow();
@@ -132,7 +129,7 @@ final class UnsafeExternalRowSorter {
}
@Override
- public InternalRow next() {
+ public UnsafeRow next() {
try {
sortedIterator.loadNext();
row.pointTo(
@@ -164,11 +161,11 @@ final class UnsafeExternalRowSorter {
}
- public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
- while (inputIterator.hasNext()) {
- insertRow(inputIterator.next());
- }
- return sort();
+ public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
+ while (inputIterator.hasNext()) {
+ insertRow(inputIterator.next());
+ }
+ return sort();
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 3f436c0..9fe877f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator
abstract sealed class SortDirection
case object Ascending extends SortDirection
@@ -37,4 +40,43 @@ case class SortOrder(child: Expression, direction: SortDirection)
override def nullable: Boolean = child.nullable
override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
+
+ def isAscending: Boolean = direction == Ascending
+}
+
+/**
+ * An expression to generate a 64-bit long prefix used in sorting.
+ */
+case class SortPrefix(child: SortOrder) extends UnaryExpression {
+
+ override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val childCode = child.child.gen(ctx)
+ val input = childCode.primitive
+ val DoublePrefixCmp = classOf[DoublePrefixComparator].getName
+
+ val (nullValue: Long, prefixCode: String) = child.child.dataType match {
+ case BooleanType =>
+ (Long.MinValue, s"$input ? 1L : 0L")
+ case _: IntegralType =>
+ (Long.MinValue, s"(long) $input")
+ case FloatType | DoubleType =>
+ (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
+ s"$DoublePrefixCmp.computePrefix((double)$input)")
+ case StringType => (0L, s"$input.getPrefix()")
+ case _ => (0L, "0L")
+ }
+
+ childCode.code +
+ s"""
+ |long ${ev.primitive} = ${nullValue}L;
+ |boolean ${ev.isNull} = false;
+ |if (!${childCode.isNull}) {
+ | ${ev.primitive} = $prefixCode;
+ |}
+ """.stripMargin
+ }
+
+ override def dataType: DataType = LongType
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 2dee354..a2145b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -18,10 +18,8 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator}
@@ -37,61 +35,15 @@ object SortPrefixUtils {
def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = {
sortOrder.dataType match {
- case StringType => PrefixComparators.STRING
- case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL
- case FloatType => PrefixComparators.FLOAT
- case DoubleType => PrefixComparators.DOUBLE
+ case StringType if sortOrder.isAscending => PrefixComparators.STRING
+ case StringType if !sortOrder.isAscending => PrefixComparators.STRING_DESC
+ case BooleanType | ByteType | ShortType | IntegerType | LongType if sortOrder.isAscending =>
+ PrefixComparators.LONG
+ case BooleanType | ByteType | ShortType | IntegerType | LongType if !sortOrder.isAscending =>
+ PrefixComparators.LONG_DESC
+ case FloatType | DoubleType if sortOrder.isAscending => PrefixComparators.DOUBLE
+ case FloatType | DoubleType if !sortOrder.isAscending => PrefixComparators.DOUBLE_DESC
case _ => NoOpPrefixComparator
}
}
-
- def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = {
- sortOrder.dataType match {
- case StringType => (row: InternalRow) => {
- PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String])
- }
- case BooleanType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1
- else 0
- }
- case ByteType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else sortOrder.child.eval(row).asInstanceOf[Byte]
- }
- case ShortType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else sortOrder.child.eval(row).asInstanceOf[Short]
- }
- case IntegerType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else sortOrder.child.eval(row).asInstanceOf[Int]
- }
- case LongType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else sortOrder.child.eval(row).asInstanceOf[Long]
- }
- case FloatType => (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX
- else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float])
- }
- case DoubleType => (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX
- else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double])
- }
- case _ => (row: InternalRow) => 0L
- }
- }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 52a9b02..03d24a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -341,8 +341,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = {
if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled &&
- UnsafeExternalSort.supportsSchema(child.schema)) {
- execution.UnsafeExternalSort(sortExprs, global, child)
+ TungstenSort.supportsSchema(child.schema)) {
+ execution.TungstenSort(sortExprs, global, child)
} else if (sqlContext.conf.externalSortEnabled) {
execution.ExternalSort(sortExprs, global, child)
} else {
http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 26dbc91..f88a45f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -229,7 +229,7 @@ private[joins] final class UnsafeHashedRelation(
// write all the values as single byte array
var totalSize = 0L
var i = 0
- while (i < values.size) {
+ while (i < values.length) {
totalSize += values(i).getSizeInBytes + 4 + 4
i += 1
}
@@ -240,7 +240,7 @@ private[joins] final class UnsafeHashedRelation(
out.writeInt(totalSize.toInt)
out.write(key.getBytes)
i = 0
- while (i < values.size) {
+ while (i < values.length) {
// [num of fields] [num of bytes] [row bytes]
// write the integer in native order, so they can be read by UNSAFE.getInt()
if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index f822088..6d903ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -17,16 +17,14 @@
package org.apache.spark.sql.execution
-import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions.{Descending, BindReferences, Attribute, SortOrder}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
-import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines various sort operators.
@@ -97,59 +95,53 @@ case class ExternalSort(
* @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will
* spill every `frequency` records.
*/
-case class UnsafeExternalSort(
+case class TungstenSort(
sortOrder: Seq[SortOrder],
global: Boolean,
child: SparkPlan,
testSpillFrequency: Int = 0)
extends UnaryNode {
- private[this] val schema: StructType = child.schema
+ override def outputsUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = false
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
override def requiredChildDistribution: Seq[Distribution] =
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
- protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
- assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled")
- def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = {
- val ordering = newOrdering(sortOrder, child.output)
- val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output)
- // Hack until we generate separate comparator implementations for ascending vs. descending
- // (or choose to codegen them):
- val prefixComparator = {
- val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression)
- if (sortOrder.head.direction == Descending) {
- new PrefixComparator {
- override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2)
- }
- } else {
- comp
- }
- }
- val prefixComputer = {
- val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression)
- new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = prefixComputer(row)
+ protected override def doExecute(): RDD[InternalRow] = {
+ val schema = child.schema
+ val childOutput = child.output
+ child.execute().mapPartitions({ iter =>
+ val ordering = newOrdering(sortOrder, childOutput)
+
+ // The comparator for comparing prefix
+ val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
+ val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+
+ // The generator for prefix
+ val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+ val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = {
+ prefixProjection.apply(row).getLong(0)
}
}
+
val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer)
if (testSpillFrequency > 0) {
sorter.setTestSpillFrequency(testSpillFrequency)
}
- sorter.sort(iterator)
- }
- child.execute().mapPartitions(doSort, preservesPartitioning = true)
+ sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+ }, preservesPartitioning = true)
}
- override def output: Seq[Attribute] = child.output
-
- override def outputOrdering: Seq[SortOrder] = sortOrder
-
- override def outputsUnsafeRows: Boolean = true
}
-@DeveloperApi
-object UnsafeExternalSort {
+object TungstenSort {
/**
* Return true if TungstenSort can sort rows with the given schema, false otherwise.
*/
http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
index 7b75f75..707cd9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -18,8 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.IsNull
+import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull}
import org.apache.spark.sql.test.TestSQLContext
class RowFormatConvertersSuite extends SparkPlanTest {
@@ -31,7 +30,7 @@ class RowFormatConvertersSuite extends SparkPlanTest {
private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null))
assert(!outputsSafe.outputsUnsafeRows)
- private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null))
+ private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null))
assert(outputsUnsafe.outputsUnsafeRows)
test("planner should insert unsafe->safe conversions when required") {
@@ -41,14 +40,14 @@ class RowFormatConvertersSuite extends SparkPlanTest {
}
test("filter can process unsafe rows") {
- val plan = Filter(IsNull(null), outputsUnsafe)
+ val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
- assert(getConverters(preparedPlan).isEmpty)
+ assert(getConverters(preparedPlan).size === 1)
assert(preparedPlan.outputsUnsafeRows)
}
test("filter can process safe rows") {
- val plan = Filter(IsNull(null), outputsSafe)
+ val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
assert(getConverters(preparedPlan).isEmpty)
assert(!preparedPlan.outputsUnsafeRows)
http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
new file mode 100644
index 0000000..4509635
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.types._
+
+class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
+
+ override def beforeAll(): Unit = {
+ TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
+ }
+
+ override def afterAll(): Unit = {
+ TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
+ }
+
+ test("sort followed by limit") {
+ checkThatPlansAgree(
+ (1 to 100).map(v => Tuple1(v)).toDF("a"),
+ (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)),
+ (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
+ sortAnswers = false
+ )
+ }
+
+ test("sorting does not crash for large inputs") {
+ val sortOrder = 'a.asc :: Nil
+ val stringLength = 1024 * 1024 * 2
+ checkThatPlansAgree(
+ Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1),
+ TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1),
+ Sort(sortOrder, global = true, _: SparkPlan),
+ sortAnswers = false
+ )
+ }
+
+ // Test sorting on different data types
+ for (
+ dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType)
+ if !dataType.isInstanceOf[DecimalType]; // We don't have an unsafe representation for decimals
+ nullable <- Seq(true, false);
+ sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
+ randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
+ ) {
+ test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
+ val inputData = Seq.fill(1000)(randomDataGenerator())
+ val inputDf = TestSQLContext.createDataFrame(
+ TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
+ StructType(StructField("a", dataType, nullable = true) :: Nil)
+ )
+ assert(TungstenSort.supportsSchema(inputDf.schema))
+ checkThatPlansAgree(
+ inputDf,
+ plan => ConvertToSafe(
+ TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)),
+ Sort(sortOrder, global = true, _: SparkPlan),
+ sortAnswers = false
+ )
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
deleted file mode 100644
index 138636b..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import scala.util.Random
-
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.types._
-
-class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
-
- override def beforeAll(): Unit = {
- TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
- }
-
- override def afterAll(): Unit = {
- TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
- }
-
- test("sort followed by limit") {
- checkThatPlansAgree(
- (1 to 100).map(v => Tuple1(v)).toDF("a"),
- (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)),
- (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
- sortAnswers = false
- )
- }
-
- test("sorting does not crash for large inputs") {
- val sortOrder = 'a.asc :: Nil
- val stringLength = 1024 * 1024 * 2
- checkThatPlansAgree(
- Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1),
- UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1),
- Sort(sortOrder, global = true, _: SparkPlan),
- sortAnswers = false
- )
- }
-
- // Test sorting on different data types
- for (
- dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType)
- if !dataType.isInstanceOf[DecimalType]; // We don't have an unsafe representation for decimals
- nullable <- Seq(true, false);
- sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
- randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
- ) {
- test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
- val inputData = Seq.fill(1000)(randomDataGenerator())
- val inputDf = TestSQLContext.createDataFrame(
- TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
- StructType(StructField("a", dataType, nullable = true) :: Nil)
- )
- assert(UnsafeExternalSort.supportsSchema(inputDf.schema))
- checkThatPlansAgree(
- inputDf,
- plan => ConvertToSafe(
- UnsafeExternalSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)),
- Sort(sortOrder, global = true, _: SparkPlan),
- sortAnswers = false
- )
- }
- }
-}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org