You are viewing a plain-text version of this content. The canonical version (originally linked here) is the original message in the commits@spark.apache.org mailing-list archive.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/31 02:17:30 UTC

spark git commit: [SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort.

Repository: spark
Updated Branches:
  refs/heads/master df3266951 -> e7a0976e9


[SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort.

Author: Reynold Xin <rx...@databricks.com>

Closes #7803 from rxin/SPARK-9458 and squashes the following commits:

5b032dc [Reynold Xin] Fix string.
b670dbb [Reynold Xin] [SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e7a0976e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e7a0976e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e7a0976e

Branch: refs/heads/master
Commit: e7a0976e991f75a7bda99509e2b040daab965ae6
Parents: df32669
Author: Reynold Xin <rx...@databricks.com>
Authored: Thu Jul 30 17:17:27 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Jul 30 17:17:27 2015 -0700

----------------------------------------------------------------------
 .../unsafe/sort/PrefixComparators.java          | 49 +++++++-----
 .../unsafe/sort/PrefixComparatorsSuite.scala    | 22 ++----
 .../sql/execution/UnsafeExternalRowSorter.java  | 27 +++----
 .../sql/catalyst/expressions/SortOrder.scala    | 44 ++++++++++-
 .../spark/sql/execution/SortPrefixUtils.scala   | 64 ++-------------
 .../spark/sql/execution/SparkStrategies.scala   |  4 +-
 .../sql/execution/joins/HashedRelation.scala    |  4 +-
 .../org/apache/spark/sql/execution/sort.scala   | 64 +++++++--------
 .../execution/RowFormatConvertersSuite.scala    | 11 ++-
 .../spark/sql/execution/TungstenSortSuite.scala | 83 ++++++++++++++++++++
 .../sql/execution/UnsafeExternalSortSuite.scala | 83 --------------------
 11 files changed, 216 insertions(+), 239 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index 600aff7..4d7e5b3 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -28,9 +28,11 @@ public class PrefixComparators {
   private PrefixComparators() {}
 
   public static final StringPrefixComparator STRING = new StringPrefixComparator();
-  public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator();
-  public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator();
+  public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc();
+  public static final LongPrefixComparator LONG = new LongPrefixComparator();
+  public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc();
   public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
+  public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc();
 
   public static final class StringPrefixComparator extends PrefixComparator {
     @Override
@@ -38,50 +40,55 @@ public class PrefixComparators {
       return UnsignedLongs.compare(aPrefix, bPrefix);
     }
 
-    public long computePrefix(UTF8String value) {
+    public static long computePrefix(UTF8String value) {
       return value == null ? 0L : value.getPrefix();
     }
   }
 
-  /**
-   * Prefix comparator for all integral types (boolean, byte, short, int, long).
-   */
-  public static final class IntegralPrefixComparator extends PrefixComparator {
+  public static final class StringPrefixComparatorDesc extends PrefixComparator {
+    @Override
+    public int compare(long bPrefix, long aPrefix) {
+      return UnsignedLongs.compare(aPrefix, bPrefix);
+    }
+  }
+
+  public static final class LongPrefixComparator extends PrefixComparator {
     @Override
     public int compare(long a, long b) {
       return (a < b) ? -1 : (a > b) ? 1 : 0;
     }
+  }
 
-    public final long NULL_PREFIX = Long.MIN_VALUE;
+  public static final class LongPrefixComparatorDesc extends PrefixComparator {
+    @Override
+    public int compare(long b, long a) {
+      return (a < b) ? -1 : (a > b) ? 1 : 0;
+    }
   }
 
-  public static final class FloatPrefixComparator extends PrefixComparator {
+  public static final class DoublePrefixComparator extends PrefixComparator {
     @Override
     public int compare(long aPrefix, long bPrefix) {
-      float a = Float.intBitsToFloat((int) aPrefix);
-      float b = Float.intBitsToFloat((int) bPrefix);
-      return Utils.nanSafeCompareFloats(a, b);
+      double a = Double.longBitsToDouble(aPrefix);
+      double b = Double.longBitsToDouble(bPrefix);
+      return Utils.nanSafeCompareDoubles(a, b);
     }
 
-    public long computePrefix(float value) {
-      return Float.floatToIntBits(value) & 0xffffffffL;
+    public static long computePrefix(double value) {
+      return Double.doubleToLongBits(value);
     }
-
-    public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY);
   }
 
-  public static final class DoublePrefixComparator extends PrefixComparator {
+  public static final class DoublePrefixComparatorDesc extends PrefixComparator {
     @Override
-    public int compare(long aPrefix, long bPrefix) {
+    public int compare(long bPrefix, long aPrefix) {
       double a = Double.longBitsToDouble(aPrefix);
       double b = Double.longBitsToDouble(bPrefix);
       return Utils.nanSafeCompareDoubles(a, b);
     }
 
-    public long computePrefix(double value) {
+    public static long computePrefix(double value) {
       return Double.doubleToLongBits(value);
     }
-
-    public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY);
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
index cf53a8a..26a2e96 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -29,8 +29,8 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
     def testPrefixComparison(s1: String, s2: String): Unit = {
       val utf8string1 = UTF8String.fromString(s1)
       val utf8string2 = UTF8String.fromString(s2)
-      val s1Prefix = PrefixComparators.STRING.computePrefix(utf8string1)
-      val s2Prefix = PrefixComparators.STRING.computePrefix(utf8string2)
+      val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1)
+      val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2)
       val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix)
 
       val cmp = UnsignedBytes.lexicographicalComparator().compare(
@@ -55,27 +55,15 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
     forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
   }
 
-  test("float prefix comparator handles NaN properly") {
-    val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001)
-    val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff)
-    assert(nan1.isNaN)
-    assert(nan2.isNaN)
-    val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1)
-    val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2)
-    assert(nan1Prefix === nan2Prefix)
-    val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue)
-    assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1)
-  }
-
   test("double prefix comparator handles NaNs properly") {
     val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
     val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
     assert(nan1.isNaN)
     assert(nan2.isNaN)
-    val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1)
-    val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2)
+    val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1)
+    val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2)
     assert(nan1Prefix === nan2Prefix)
-    val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue)
+    val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue)
     assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 4c3f2c6..68c49fe 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -48,7 +48,6 @@ final class UnsafeExternalRowSorter {
   private long numRowsInserted = 0;
 
   private final StructType schema;
-  private final UnsafeProjection unsafeProjection;
   private final PrefixComputer prefixComputer;
   private final UnsafeExternalSorter sorter;
 
@@ -62,7 +61,6 @@ final class UnsafeExternalRowSorter {
       PrefixComparator prefixComparator,
       PrefixComputer prefixComputer) throws IOException {
     this.schema = schema;
-    this.unsafeProjection = UnsafeProjection.create(schema);
     this.prefixComputer = prefixComputer;
     final SparkEnv sparkEnv = SparkEnv.get();
     final TaskContext taskContext = TaskContext.get();
@@ -88,13 +86,12 @@ final class UnsafeExternalRowSorter {
   }
 
   @VisibleForTesting
-  void insertRow(InternalRow row) throws IOException {
-    UnsafeRow unsafeRow = unsafeProjection.apply(row);
+  void insertRow(UnsafeRow row) throws IOException {
     final long prefix = prefixComputer.computePrefix(row);
     sorter.insertRecord(
-      unsafeRow.getBaseObject(),
-      unsafeRow.getBaseOffset(),
-      unsafeRow.getSizeInBytes(),
+      row.getBaseObject(),
+      row.getBaseOffset(),
+      row.getSizeInBytes(),
       prefix
     );
     numRowsInserted++;
@@ -113,7 +110,7 @@ final class UnsafeExternalRowSorter {
   }
 
   @VisibleForTesting
-  Iterator<InternalRow> sort() throws IOException {
+  Iterator<UnsafeRow> sort() throws IOException {
     try {
       final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
       if (!sortedIterator.hasNext()) {
@@ -121,7 +118,7 @@ final class UnsafeExternalRowSorter {
         // here in order to prevent memory leaks.
         cleanupResources();
       }
-      return new AbstractScalaRowIterator() {
+      return new AbstractScalaRowIterator<UnsafeRow>() {
 
         private final int numFields = schema.length();
         private UnsafeRow row = new UnsafeRow();
@@ -132,7 +129,7 @@ final class UnsafeExternalRowSorter {
         }
 
         @Override
-        public InternalRow next() {
+        public UnsafeRow next() {
           try {
             sortedIterator.loadNext();
             row.pointTo(
@@ -164,11 +161,11 @@ final class UnsafeExternalRowSorter {
   }
 
 
-  public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
-      while (inputIterator.hasNext()) {
-        insertRow(inputIterator.next());
-      }
-      return sort();
+  public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
+    while (inputIterator.hasNext()) {
+      insertRow(inputIterator.next());
+    }
+    return sort();
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 3f436c0..9fe877f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator
 
 abstract sealed class SortDirection
 case object Ascending extends SortDirection
@@ -37,4 +40,43 @@ case class SortOrder(child: Expression, direction: SortDirection)
   override def nullable: Boolean = child.nullable
 
   override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
+
+  def isAscending: Boolean = direction == Ascending
+}
+
+/**
+ * An expression to generate a 64-bit long prefix used in sorting.
+ */
+case class SortPrefix(child: SortOrder) extends UnaryExpression {
+
+  override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val childCode = child.child.gen(ctx)
+    val input = childCode.primitive
+    val DoublePrefixCmp = classOf[DoublePrefixComparator].getName
+
+    val (nullValue: Long, prefixCode: String) = child.child.dataType match {
+      case BooleanType =>
+        (Long.MinValue, s"$input ? 1L : 0L")
+      case _: IntegralType =>
+        (Long.MinValue, s"(long) $input")
+      case FloatType | DoubleType =>
+        (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
+          s"$DoublePrefixCmp.computePrefix((double)$input)")
+      case StringType => (0L, s"$input.getPrefix()")
+      case _ => (0L, "0L")
+    }
+
+    childCode.code +
+    s"""
+      |long ${ev.primitive} = ${nullValue}L;
+      |boolean ${ev.isNull} = false;
+      |if (!${childCode.isNull}) {
+      |  ${ev.primitive} = $prefixCode;
+      |}
+    """.stripMargin
+  }
+
+  override def dataType: DataType = LongType
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 2dee354..a2145b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -18,10 +18,8 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.SortOrder
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator}
 
 
@@ -37,61 +35,15 @@ object SortPrefixUtils {
 
   def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = {
     sortOrder.dataType match {
-      case StringType => PrefixComparators.STRING
-      case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL
-      case FloatType => PrefixComparators.FLOAT
-      case DoubleType => PrefixComparators.DOUBLE
+      case StringType if sortOrder.isAscending => PrefixComparators.STRING
+      case StringType if !sortOrder.isAscending => PrefixComparators.STRING_DESC
+      case BooleanType | ByteType | ShortType | IntegerType | LongType if sortOrder.isAscending =>
+        PrefixComparators.LONG
+      case BooleanType | ByteType | ShortType | IntegerType | LongType if !sortOrder.isAscending =>
+        PrefixComparators.LONG_DESC
+      case FloatType | DoubleType if sortOrder.isAscending => PrefixComparators.DOUBLE
+      case FloatType | DoubleType if !sortOrder.isAscending => PrefixComparators.DOUBLE_DESC
       case _ => NoOpPrefixComparator
     }
   }
-
-  def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = {
-    sortOrder.dataType match {
-      case StringType => (row: InternalRow) => {
-        PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String])
-      }
-      case BooleanType =>
-        (row: InternalRow) => {
-          val exprVal = sortOrder.child.eval(row)
-          if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
-          else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1
-          else 0
-        }
-      case ByteType =>
-        (row: InternalRow) => {
-          val exprVal = sortOrder.child.eval(row)
-          if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
-          else sortOrder.child.eval(row).asInstanceOf[Byte]
-        }
-      case ShortType =>
-        (row: InternalRow) => {
-          val exprVal = sortOrder.child.eval(row)
-          if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
-          else sortOrder.child.eval(row).asInstanceOf[Short]
-        }
-      case IntegerType =>
-        (row: InternalRow) => {
-          val exprVal = sortOrder.child.eval(row)
-          if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
-          else sortOrder.child.eval(row).asInstanceOf[Int]
-        }
-      case LongType =>
-        (row: InternalRow) => {
-          val exprVal = sortOrder.child.eval(row)
-          if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
-          else sortOrder.child.eval(row).asInstanceOf[Long]
-        }
-      case FloatType => (row: InternalRow) => {
-        val exprVal = sortOrder.child.eval(row)
-        if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX
-        else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float])
-      }
-      case DoubleType => (row: InternalRow) => {
-        val exprVal = sortOrder.child.eval(row)
-        if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX
-        else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double])
-      }
-      case _ => (row: InternalRow) => 0L
-    }
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 52a9b02..03d24a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -341,8 +341,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
      */
     def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = {
       if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled &&
-        UnsafeExternalSort.supportsSchema(child.schema)) {
-        execution.UnsafeExternalSort(sortExprs, global, child)
+        TungstenSort.supportsSchema(child.schema)) {
+        execution.TungstenSort(sortExprs, global, child)
       } else if (sqlContext.conf.externalSortEnabled) {
         execution.ExternalSort(sortExprs, global, child)
       } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 26dbc91..f88a45f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -229,7 +229,7 @@ private[joins] final class UnsafeHashedRelation(
       // write all the values as single byte array
       var totalSize = 0L
       var i = 0
-      while (i < values.size) {
+      while (i < values.length) {
         totalSize += values(i).getSizeInBytes + 4 + 4
         i += 1
       }
@@ -240,7 +240,7 @@ private[joins] final class UnsafeHashedRelation(
       out.writeInt(totalSize.toInt)
       out.write(key.getBytes)
       i = 0
-      while (i < values.size) {
+      while (i < values.length) {
         // [num of fields] [num of bytes] [row bytes]
         // write the integer in native order, so they can be read by UNSAFE.getInt()
         if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {

http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index f822088..6d903ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -17,16 +17,14 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions.{Descending, BindReferences, Attribute, SortOrder}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
-import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // This file defines various sort operators.
@@ -97,59 +95,53 @@ case class ExternalSort(
  * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will
  *                           spill every `frequency` records.
  */
-case class UnsafeExternalSort(
+case class TungstenSort(
     sortOrder: Seq[SortOrder],
     global: Boolean,
     child: SparkPlan,
     testSpillFrequency: Int = 0)
   extends UnaryNode {
 
-  private[this] val schema: StructType = child.schema
+  override def outputsUnsafeRows: Boolean = true
+  override def canProcessUnsafeRows: Boolean = true
+  override def canProcessSafeRows: Boolean = false
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputOrdering: Seq[SortOrder] = sortOrder
 
   override def requiredChildDistribution: Seq[Distribution] =
     if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
 
-  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
-    assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled")
-    def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = {
-      val ordering = newOrdering(sortOrder, child.output)
-      val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output)
-      // Hack until we generate separate comparator implementations for ascending vs. descending
-      // (or choose to codegen them):
-      val prefixComparator = {
-        val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression)
-        if (sortOrder.head.direction == Descending) {
-          new PrefixComparator {
-            override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2)
-          }
-        } else {
-          comp
-        }
-      }
-      val prefixComputer = {
-        val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression)
-        new UnsafeExternalRowSorter.PrefixComputer {
-          override def computePrefix(row: InternalRow): Long = prefixComputer(row)
+  protected override def doExecute(): RDD[InternalRow] = {
+    val schema = child.schema
+    val childOutput = child.output
+    child.execute().mapPartitions({ iter =>
+      val ordering = newOrdering(sortOrder, childOutput)
+
+      // The comparator for comparing prefix
+      val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
+      val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+
+      // The generator for prefix
+      val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+      val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+        override def computePrefix(row: InternalRow): Long = {
+          prefixProjection.apply(row).getLong(0)
         }
       }
+
       val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer)
       if (testSpillFrequency > 0) {
         sorter.setTestSpillFrequency(testSpillFrequency)
       }
-      sorter.sort(iterator)
-    }
-    child.execute().mapPartitions(doSort, preservesPartitioning = true)
+      sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+    }, preservesPartitioning = true)
   }
 
-  override def output: Seq[Attribute] = child.output
-
-  override def outputOrdering: Seq[SortOrder] = sortOrder
-
-  override def outputsUnsafeRows: Boolean = true
 }
 
-@DeveloperApi
-object UnsafeExternalSort {
+object TungstenSort {
   /**
    * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
index 7b75f75..707cd9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -18,8 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.IsNull
+import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull}
 import org.apache.spark.sql.test.TestSQLContext
 
 class RowFormatConvertersSuite extends SparkPlanTest {
@@ -31,7 +30,7 @@ class RowFormatConvertersSuite extends SparkPlanTest {
 
   private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null))
   assert(!outputsSafe.outputsUnsafeRows)
-  private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null))
+  private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null))
   assert(outputsUnsafe.outputsUnsafeRows)
 
   test("planner should insert unsafe->safe conversions when required") {
@@ -41,14 +40,14 @@ class RowFormatConvertersSuite extends SparkPlanTest {
   }
 
   test("filter can process unsafe rows") {
-    val plan = Filter(IsNull(null), outputsUnsafe)
+    val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe)
     val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
-    assert(getConverters(preparedPlan).isEmpty)
+    assert(getConverters(preparedPlan).size === 1)
     assert(preparedPlan.outputsUnsafeRows)
   }
 
   test("filter can process safe rows") {
-    val plan = Filter(IsNull(null), outputsSafe)
+    val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe)
     val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
     assert(getConverters(preparedPlan).isEmpty)
     assert(!preparedPlan.outputsUnsafeRows)

http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
new file mode 100644
index 0000000..4509635
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.types._
+
+class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
+
+  override def beforeAll(): Unit = {
+    TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
+  }
+
+  override def afterAll(): Unit = {
+    TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
+  }
+
+  test("sort followed by limit") {
+    checkThatPlansAgree(
+      (1 to 100).map(v => Tuple1(v)).toDF("a"),
+      (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)),
+      (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
+      sortAnswers = false
+    )
+  }
+
+  test("sorting does not crash for large inputs") {
+    val sortOrder = 'a.asc :: Nil
+    val stringLength = 1024 * 1024 * 2
+    checkThatPlansAgree(
+      Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1),
+      TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1),
+      Sort(sortOrder, global = true, _: SparkPlan),
+      sortAnswers = false
+    )
+  }
+
+  // Test sorting on different data types
+  for (
+    dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType)
+    if !dataType.isInstanceOf[DecimalType]; // We don't have an unsafe representation for decimals
+    nullable <- Seq(true, false);
+    sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
+    randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
+  ) {
+    test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
+      val inputData = Seq.fill(1000)(randomDataGenerator())
+      val inputDf = TestSQLContext.createDataFrame(
+        TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
+        StructType(StructField("a", dataType, nullable = true) :: Nil)
+      )
+      assert(TungstenSort.supportsSchema(inputDf.schema))
+      checkThatPlansAgree(
+        inputDf,
+        plan => ConvertToSafe(
+          TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)),
+        Sort(sortOrder, global = true, _: SparkPlan),
+        sortAnswers = false
+      )
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e7a0976e/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
deleted file mode 100644
index 138636b..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import scala.util.Random
-
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.types._
-
-class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
-
-  override def beforeAll(): Unit = {
-    TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
-  }
-
-  override def afterAll(): Unit = {
-    TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
-  }
-
-  test("sort followed by limit") {
-    checkThatPlansAgree(
-      (1 to 100).map(v => Tuple1(v)).toDF("a"),
-      (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)),
-      (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
-      sortAnswers = false
-    )
-  }
-
-  test("sorting does not crash for large inputs") {
-    val sortOrder = 'a.asc :: Nil
-    val stringLength = 1024 * 1024 * 2
-    checkThatPlansAgree(
-      Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1),
-      UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1),
-      Sort(sortOrder, global = true, _: SparkPlan),
-      sortAnswers = false
-    )
-  }
-
-  // Test sorting on different data types
-  for (
-    dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType)
-    if !dataType.isInstanceOf[DecimalType]; // We don't have an unsafe representation for decimals
-    nullable <- Seq(true, false);
-    sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
-    randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
-  ) {
-    test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
-      val inputData = Seq.fill(1000)(randomDataGenerator())
-      val inputDf = TestSQLContext.createDataFrame(
-        TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
-        StructType(StructField("a", dataType, nullable = true) :: Nil)
-      )
-      assert(UnsafeExternalSort.supportsSchema(inputDf.schema))
-      checkThatPlansAgree(
-        inputDf,
-        plan => ConvertToSafe(
-          UnsafeExternalSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)),
-        Sort(sortOrder, global = true, _: SparkPlan),
-        sortAnswers = false
-      )
-    }
-  }
-}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org