You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/11/20 05:50:23 UTC

spark git commit: [SPARK-18458][CORE] Fix signed integer overflow problem at an expression in RadixSort.java

Repository: spark
Updated Branches:
  refs/heads/master 856e00420 -> d93b65524


[SPARK-18458][CORE] Fix signed integer overflow problem at an expression in RadixSort.java

## What changes were proposed in this pull request?

This PR prevents the result of an expression from becoming negative due to signed integer overflow (e.g. 0x10?????? * 8 < 0). It casts each operand to `long` before the calculation is performed. Since the arithmetic is then carried out in `long`, the result of the expression stays positive.

## How was this patch tested?

Manually executed query82 of TPC-DS with 100TB

Author: Kazuaki Ishizaki <is...@jp.ibm.com>

Closes #15907 from kiszk/SPARK-18458.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d93b6552
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d93b6552
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d93b6552

Branch: refs/heads/master
Commit: d93b6552473468df297a08c0bef9ea0bf0f5c13a
Parents: 856e004
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Authored: Sat Nov 19 21:50:20 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sat Nov 19 21:50:20 2016 -0800

----------------------------------------------------------------------
 .../util/collection/unsafe/sort/RadixSort.java  | 48 ++++++++++----------
 .../unsafe/sort/UnsafeInMemorySorter.java       |  2 +-
 .../collection/unsafe/sort/RadixSortSuite.scala | 28 ++++++------
 3 files changed, 40 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d93b6552/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
index 4043617..3dd3184 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.util.collection.unsafe.sort;
 
+import com.google.common.primitives.Ints;
+
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
 
@@ -40,14 +42,14 @@ public class RadixSort {
    *         of always copying the data back to position zero for efficiency.
    */
   public static int sort(
-      LongArray array, int numRecords, int startByteIndex, int endByteIndex,
+      LongArray array, long numRecords, int startByteIndex, int endByteIndex,
       boolean desc, boolean signed) {
     assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
     assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
     assert endByteIndex > startByteIndex;
     assert numRecords * 2 <= array.size();
-    int inIndex = 0;
-    int outIndex = numRecords;
+    long inIndex = 0;
+    long outIndex = numRecords;
     if (numRecords > 0) {
       long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
       for (int i = startByteIndex; i <= endByteIndex; i++) {
@@ -55,13 +57,13 @@ public class RadixSort {
           sortAtByte(
             array, numRecords, counts[i], i, inIndex, outIndex,
             desc, signed && i == endByteIndex);
-          int tmp = inIndex;
+          long tmp = inIndex;
           inIndex = outIndex;
           outIndex = tmp;
         }
       }
     }
-    return inIndex;
+    return Ints.checkedCast(inIndex);
   }
 
   /**
@@ -78,14 +80,14 @@ public class RadixSort {
    * @param signed whether this is a signed (two's complement) sort (only applies to last byte).
    */
   private static void sortAtByte(
-      LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
+      LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
       boolean desc, boolean signed) {
     assert counts.length == 256;
     long[] offsets = transformCountsToOffsets(
-      counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed);
+      counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed);
     Object baseObject = array.getBaseObject();
-    long baseOffset = array.getBaseOffset() + inIndex * 8;
-    long maxOffset = baseOffset + numRecords * 8;
+    long baseOffset = array.getBaseOffset() + inIndex * 8L;
+    long maxOffset = baseOffset + numRecords * 8L;
     for (long offset = baseOffset; offset < maxOffset; offset += 8) {
       long value = Platform.getLong(baseObject, offset);
       int bucket = (int)((value >>> (byteIdx * 8)) & 0xff);
@@ -106,13 +108,13 @@ public class RadixSort {
    *         significant byte. If the byte does not need sorting the array will be null.
    */
   private static long[][] getCounts(
-      LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
+      LongArray array, long numRecords, int startByteIndex, int endByteIndex) {
     long[][] counts = new long[8][];
     // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting.
     // If all the byte values at a particular index are the same we don't need to count it.
     long bitwiseMax = 0;
     long bitwiseMin = -1L;
-    long maxOffset = array.getBaseOffset() + numRecords * 8;
+    long maxOffset = array.getBaseOffset() + numRecords * 8L;
     Object baseObject = array.getBaseObject();
     for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
       long value = Platform.getLong(baseObject, offset);
@@ -146,18 +148,18 @@ public class RadixSort {
    * @return the input counts array.
    */
   private static long[] transformCountsToOffsets(
-      long[] counts, int numRecords, long outputOffset, int bytesPerRecord,
+      long[] counts, long numRecords, long outputOffset, long bytesPerRecord,
       boolean desc, boolean signed) {
     assert counts.length == 256;
     int start = signed ? 128 : 0;  // output the negative records first (values 129-255).
     if (desc) {
-      int pos = numRecords;
+      long pos = numRecords;
       for (int i = start; i < start + 256; i++) {
         pos -= counts[i & 0xff];
         counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
       }
     } else {
-      int pos = 0;
+      long pos = 0;
       for (int i = start; i < start + 256; i++) {
         long tmp = counts[i & 0xff];
         counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
@@ -176,8 +178,8 @@ public class RadixSort {
    */
   public static int sortKeyPrefixArray(
       LongArray array,
-      int startIndex,
-      int numRecords,
+      long startIndex,
+      long numRecords,
       int startByteIndex,
       int endByteIndex,
       boolean desc,
@@ -186,8 +188,8 @@ public class RadixSort {
     assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
     assert endByteIndex > startByteIndex;
     assert numRecords * 4 <= array.size();
-    int inIndex = startIndex;
-    int outIndex = startIndex + numRecords * 2;
+    long inIndex = startIndex;
+    long outIndex = startIndex + numRecords * 2L;
     if (numRecords > 0) {
       long[][] counts = getKeyPrefixArrayCounts(
         array, startIndex, numRecords, startByteIndex, endByteIndex);
@@ -196,13 +198,13 @@ public class RadixSort {
           sortKeyPrefixArrayAtByte(
             array, numRecords, counts[i], i, inIndex, outIndex,
             desc, signed && i == endByteIndex);
-          int tmp = inIndex;
+          long tmp = inIndex;
           inIndex = outIndex;
           outIndex = tmp;
         }
       }
     }
-    return inIndex;
+    return Ints.checkedCast(inIndex);
   }
 
   /**
@@ -210,7 +212,7 @@ public class RadixSort {
    * getCounts with some added parameters but that seems to hurt in benchmarks.
    */
   private static long[][] getKeyPrefixArrayCounts(
-      LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) {
+      LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) {
     long[][] counts = new long[8][];
     long bitwiseMax = 0;
     long bitwiseMin = -1L;
@@ -238,11 +240,11 @@ public class RadixSort {
    * Specialization of sortAtByte() for key-prefix arrays.
    */
   private static void sortKeyPrefixArrayAtByte(
-      LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
+      LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
       boolean desc, boolean signed) {
     assert counts.length == 256;
     long[] offsets = transformCountsToOffsets(
-      counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed);
+      counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed);
     Object baseObject = array.getBaseObject();
     long baseOffset = array.getBaseOffset() + inIndex * 8L;
     long maxOffset = baseOffset + numRecords * 16L;

http://git-wip-us.apache.org/repos/asf/spark/blob/d93b6552/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 2a71e68..252a35e 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -322,7 +322,7 @@ public final class UnsafeInMemorySorter {
     if (sortComparator != null) {
       if (this.radixSortSupport != null) {
         offset = RadixSort.sortKeyPrefixArray(
-          array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7,
+          array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7,
           radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
       } else {
         MemoryBlock unused = new MemoryBlock(

http://git-wip-us.apache.org/repos/asf/spark/blob/d93b6552/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
index 366ffda..d5956ea 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
@@ -22,6 +22,8 @@ import java.util.{Arrays, Comparator}
 
 import scala.util.Random
 
+import com.google.common.primitives.Ints
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.internal.Logging
 import org.apache.spark.unsafe.array.LongArray
@@ -30,7 +32,7 @@ import org.apache.spark.util.collection.Sorter
 import org.apache.spark.util.random.XORShiftRandom
 
 class RadixSortSuite extends SparkFunSuite with Logging {
-  private val N = 10000  // scale this down for more readable results
+  private val N = 10000L  // scale this down for more readable results
 
   /**
    * Describes a type of sort to test, e.g. two's complement descending. Each sort type has
@@ -73,22 +75,22 @@ class RadixSortSuite extends SparkFunSuite with Logging {
       },
       2, 4, false, false, true))
 
-  private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = {
-    val ref = Array.tabulate[Long](size) { i => rand }
-    val extended = ref ++ Array.fill[Long](size)(0)
+  private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = {
+    val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand }
+    val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0)
     (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended)))
   }
 
-  private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = {
-    val ref = Array.tabulate[Long](size * 2) { i => rand }
-    val extended = ref ++ Array.fill[Long](size * 2)(0)
+  private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = {
+    val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand }
+    val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0)
     (new LongArray(MemoryBlock.fromLongArray(ref)),
      new LongArray(MemoryBlock.fromLongArray(extended)))
   }
 
-  private def collectToArray(array: LongArray, offset: Int, length: Int): Array[Long] = {
+  private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = {
     var i = 0
-    val out = new Array[Long](length)
+    val out = new Array[Long](Ints.checkedCast(length))
     while (i < length) {
       out(i) = array.get(offset + i)
       i += 1
@@ -107,15 +109,13 @@ class RadixSortSuite extends SparkFunSuite with Logging {
     }
   }
 
-  private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
+  private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) {
     val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
     new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
-      buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
+      buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] {
         override def compare(
             r1: RecordPointerAndKeyPrefix,
-            r2: RecordPointerAndKeyPrefix): Int = {
-          refCmp.compare(r1.keyPrefix, r2.keyPrefix)
-        }
+            r2: RecordPointerAndKeyPrefix): Int = refCmp.compare(r1.keyPrefix, r2.keyPrefix)
       })
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org