You are viewing a plain-text version of this content. The hyperlink to its canonical version was not preserved in this plain-text copy.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/01/13 21:29:06 UTC

spark git commit: [SPARK-12642][SQL] improve the hash expression to be decoupled from unsafe row

Repository: spark
Updated Branches:
  refs/heads/master e4e0b3f7b -> c2ea79f96


[SPARK-12642][SQL] improve the hash expression to be decoupled from unsafe row

https://issues.apache.org/jira/browse/SPARK-12642

Author: Wenchen Fan <we...@databricks.com>

Closes #10694 from cloud-fan/hash-expr.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c2ea79f9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c2ea79f9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c2ea79f9

Branch: refs/heads/master
Commit: c2ea79f96acd076351b48162644ed1cff4c8e090
Parents: e4e0b3f
Author: Wenchen Fan <we...@databricks.com>
Authored: Wed Jan 13 12:29:02 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Wed Jan 13 12:29:02 2016 -0800

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |   2 +-
 .../sql/catalyst/expressions/UnsafeRow.java     |   4 -
 .../spark/sql/catalyst/expressions/misc.scala   | 251 ++++++++++++++++++-
 .../expressions/MiscFunctionsSuite.scala        |   6 +-
 .../spark/sql/sources/BucketedWriteSuite.scala  |  26 +-
 .../spark/unsafe/hash/Murmur3_x86_32.java       |  28 ++-
 6 files changed, 288 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c2ea79f9/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b0390cb..719eca8 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1023,7 +1023,7 @@ def hash(*cols):
     """Calculates the hash code of given columns, and returns the result as a int column.
 
     >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect()
-    [Row(hash=1358996357)]
+    [Row(hash=-757602832)]
     """
     sc = SparkContext._active_spark_context
     jc = sc._jvm.functions.hash(_to_seq(sc, cols, _to_java_column))

http://git-wip-us.apache.org/repos/asf/spark/blob/c2ea79f9/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index b8d3c49..1a35193 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -566,10 +566,6 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
     return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
   }
 
-  public int hashCode(int seed) {
-    return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, seed);
-  }
-
   @Override
   public boolean equals(Object other) {
     if (other instanceof UnsafeRow) {

http://git-wip-us.apache.org/repos/asf/spark/blob/c2ea79f9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index cc406a3..4751fbe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -25,8 +25,11 @@ import org.apache.commons.codec.digest.DigestUtils
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.hash.Murmur3_x86_32
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.Platform
 
 /**
  * A function that calculates an MD5 128-bit checksum and returns it as a hex string
@@ -184,8 +187,31 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp
  * A function that calculates hash value for a group of expressions.  Note that the `seed` argument
  * is not exposed to users and should only be set inside spark SQL.
  *
- * Internally this function will write arguments into an [[UnsafeRow]], and calculate hash code of
- * the unsafe row using murmur3 hasher with a seed.
+ * The hash value for an expression depends on its type and seed:
+ *  - null:               seed
+ *  - boolean:            turn boolean into int, 1 for true, 0 for false, and then use murmur3 to
+ *                        hash this int with seed.
+ *  - byte, short, int:   use murmur3 to hash the input as int with seed.
+ *  - long:               use murmur3 to hash the long input with seed.
+ *  - float:              turn it into int: java.lang.Float.floatToIntBits(input), and hash it.
+ *  - double:             turn it into long: java.lang.Double.doubleToLongBits(input), and hash it.
+ *  - decimal:            if it's a small decimal, i.e. precision <= 18, turn it into long and hash
+ *                        it. Else, turn it into bytes and hash it.
+ *  - calendar interval:  hash `microseconds` first, and use the result as seed to hash `months`.
+ *  - binary:             use murmur3 to hash the bytes with seed.
+ *  - string:             get the bytes of string and hash it.
+ *  - array:              The `result` starts with seed, then use `result` as seed, recursively
+ *                        calculate hash value for each element, and assign the element hash value
+ *                        to `result`.
+ *  - map:                The `result` starts with seed, then use `result` as seed, recursively
+ *                        calculate hash value for each key-value, and assign the key-value hash
+ *                        value to `result`.
+ *  - struct:             The `result` starts with seed, then use `result` as seed, recursively
+ *                        calculate hash value for each field, and assign the field hash value to
+ *                        `result`.
+ *
+ * Finally we aggregate the hash values for each expression by the same way of struct.
+ *
  * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle
  * and bucketing have same data distribution.
  */
@@ -206,22 +232,225 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
     }
   }
 
-  private lazy val unsafeProjection = UnsafeProjection.create(children)
+  override def prettyName: String = "hash"
+
+  override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)"
 
   override def eval(input: InternalRow): Any = {
-    unsafeProjection(input).hashCode(seed)
+    var hash = seed
+    var i = 0
+    val len = children.length
+    while (i < len) {
+      hash = computeHash(children(i).eval(input), children(i).dataType, hash)
+      i += 1
+    }
+    hash
   }
 
+  private def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
+    def hashInt(i: Int): Int = Murmur3_x86_32.hashInt(i, seed)
+    def hashLong(l: Long): Int = Murmur3_x86_32.hashLong(l, seed)
+
+    value match {
+      case null => seed
+      case b: Boolean => hashInt(if (b) 1 else 0)
+      case b: Byte => hashInt(b)
+      case s: Short => hashInt(s)
+      case i: Int => hashInt(i)
+      case l: Long => hashLong(l)
+      case f: Float => hashInt(java.lang.Float.floatToIntBits(f))
+      case d: Double => hashLong(java.lang.Double.doubleToLongBits(d))
+      case d: Decimal =>
+        val precision = dataType.asInstanceOf[DecimalType].precision
+        if (precision <= Decimal.MAX_LONG_DIGITS) {
+          hashLong(d.toUnscaledLong)
+        } else {
+          val bytes = d.toJavaBigDecimal.unscaledValue().toByteArray
+          Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed)
+        }
+      case c: CalendarInterval => Murmur3_x86_32.hashInt(c.months, hashLong(c.microseconds))
+      case a: Array[Byte] =>
+        Murmur3_x86_32.hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed)
+      case s: UTF8String =>
+        Murmur3_x86_32.hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed)
+
+      case array: ArrayData =>
+        val elementType = dataType match {
+          case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType
+          case ArrayType(et, _) => et
+        }
+        var result = seed
+        var i = 0
+        while (i < array.numElements()) {
+          result = computeHash(array.get(i, elementType), elementType, result)
+          i += 1
+        }
+        result
+
+      case map: MapData =>
+        val (kt, vt) = dataType match {
+          case udt: UserDefinedType[_] =>
+            val mapType = udt.sqlType.asInstanceOf[MapType]
+            mapType.keyType -> mapType.valueType
+          case MapType(kt, vt, _) => kt -> vt
+        }
+        val keys = map.keyArray()
+        val values = map.valueArray()
+        var result = seed
+        var i = 0
+        while (i < map.numElements()) {
+          result = computeHash(keys.get(i, kt), kt, result)
+          result = computeHash(values.get(i, vt), vt, result)
+          i += 1
+        }
+        result
+
+      case struct: InternalRow =>
+        val types: Array[DataType] = dataType match {
+          case udt: UserDefinedType[_] =>
+            udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray
+          case StructType(fields) => fields.map(_.dataType)
+        }
+        var result = seed
+        var i = 0
+        val len = struct.numFields
+        while (i < len) {
+          result = computeHash(struct.get(i, types(i)), types(i), result)
+          i += 1
+        }
+        result
+    }
+  }
+
+
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val unsafeRow = GenerateUnsafeProjection.createCode(ctx, children)
     ev.isNull = "false"
+    val childrenHash = children.zipWithIndex.map {
+      case (child, dt) =>
+        val childGen = child.gen(ctx)
+        val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx)
+        s"""
+          ${childGen.code}
+          if (!${childGen.isNull}) {
+            ${childHash.code}
+            ${ev.value} = ${childHash.value};
+          }
+        """
+    }.mkString("\n")
     s"""
-      ${unsafeRow.code}
-      final int ${ev.value} = ${unsafeRow.value}.hashCode($seed);
+      int ${ev.value} = $seed;
+      $childrenHash
     """
   }
 
-  override def prettyName: String = "hash"
-
-  override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)"
+  private def computeHash(
+      input: String,
+      dataType: DataType,
+      seed: String,
+      ctx: CodeGenContext): GeneratedExpressionCode = {
+    val hasher = classOf[Murmur3_x86_32].getName
+    def hashInt(i: String): GeneratedExpressionCode = inlineValue(s"$hasher.hashInt($i, $seed)")
+    def hashLong(l: String): GeneratedExpressionCode = inlineValue(s"$hasher.hashLong($l, $seed)")
+    def inlineValue(v: String): GeneratedExpressionCode =
+      GeneratedExpressionCode(code = "", isNull = "false", value = v)
+
+    dataType match {
+      case NullType => inlineValue(seed)
+      case BooleanType => hashInt(s"$input ? 1 : 0")
+      case ByteType | ShortType | IntegerType | DateType => hashInt(input)
+      case LongType | TimestampType => hashLong(input)
+      case FloatType => hashInt(s"Float.floatToIntBits($input)")
+      case DoubleType => hashLong(s"Double.doubleToLongBits($input)")
+      case d: DecimalType =>
+        if (d.precision <= Decimal.MAX_LONG_DIGITS) {
+          hashLong(s"$input.toUnscaledLong()")
+        } else {
+          val bytes = ctx.freshName("bytes")
+          val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();"
+          val offset = "Platform.BYTE_ARRAY_OFFSET"
+          val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)"
+          GeneratedExpressionCode(code, "false", result)
+        }
+      case CalendarIntervalType =>
+        val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)"
+        val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)"
+        inlineValue(monthsHash)
+      case BinaryType =>
+        val offset = "Platform.BYTE_ARRAY_OFFSET"
+        inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)")
+      case StringType =>
+        val baseObject = s"$input.getBaseObject()"
+        val baseOffset = s"$input.getBaseOffset()"
+        val numBytes = s"$input.numBytes()"
+        inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)")
+
+      case ArrayType(et, _) =>
+        val result = ctx.freshName("result")
+        val index = ctx.freshName("index")
+        val element = ctx.freshName("element")
+        val elementHash = computeHash(element, et, result, ctx)
+        val code =
+          s"""
+            int $result = $seed;
+            for (int $index = 0; $index < $input.numElements(); $index++) {
+              if (!$input.isNullAt($index)) {
+                final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)};
+                ${elementHash.code}
+                $result = ${elementHash.value};
+              }
+            }
+          """
+        GeneratedExpressionCode(code, "false", result)
+
+      case MapType(kt, vt, _) =>
+        val result = ctx.freshName("result")
+        val index = ctx.freshName("index")
+        val keys = ctx.freshName("keys")
+        val values = ctx.freshName("values")
+        val key = ctx.freshName("key")
+        val value = ctx.freshName("value")
+        val keyHash = computeHash(key, kt, result, ctx)
+        val valueHash = computeHash(value, vt, result, ctx)
+        val code =
+          s"""
+            int $result = $seed;
+            final ArrayData $keys = $input.keyArray();
+            final ArrayData $values = $input.valueArray();
+            for (int $index = 0; $index < $input.numElements(); $index++) {
+              final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)};
+              ${keyHash.code}
+              $result = ${keyHash.value};
+              if (!$values.isNullAt($index)) {
+                final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)};
+                ${valueHash.code}
+                $result = ${valueHash.value};
+              }
+            }
+          """
+        GeneratedExpressionCode(code, "false", result)
+
+      case StructType(fields) =>
+        val result = ctx.freshName("result")
+        val fieldsHash = fields.map(_.dataType).zipWithIndex.map {
+          case (dt, index) =>
+            val field = ctx.freshName("field")
+            val fieldHash = computeHash(field, dt, result, ctx)
+            s"""
+              if (!$input.isNullAt($index)) {
+                final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)};
+                ${fieldHash.code}
+                $result = ${fieldHash.value};
+              }
+            """
+        }.mkString("\n")
+        val code =
+          s"""
+            int $result = $seed;
+            $fieldsHash
+          """
+        GeneratedExpressionCode(code, "false", result)
+
+      case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c2ea79f9/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
index 64161be..75131a6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
@@ -79,7 +79,8 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       .add("long", LongType)
       .add("float", FloatType)
       .add("double", DoubleType)
-      .add("decimal", DecimalType.SYSTEM_DEFAULT)
+      .add("bigDecimal", DecimalType.SYSTEM_DEFAULT)
+      .add("smallDecimal", DecimalType.USER_DEFAULT)
       .add("string", StringType)
       .add("binary", BinaryType)
       .add("date", DateType)
@@ -126,7 +127,8 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
         val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map {
           case (value, dt) => Literal.create(value, dt)
         }
-        checkEvaluation(Murmur3Hash(literals, seed), input.hashCode(seed))
+        // Only test the interpreted version has same result with codegen version.
+        checkEvaluation(Murmur3Hash(literals, seed), Murmur3Hash(literals, seed).eval())
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/c2ea79f9/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 7f17457..b718b7c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.sources
 import java.io.File
 
 import org.apache.spark.sql.{AnalysisException, QueryTest}
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.{Murmur3Hash, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
 
 class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   import testImplicits._
@@ -70,6 +71,8 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     }
   }
 
+  private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+
   private def testBucketing(
       dataDir: File,
       source: String,
@@ -82,27 +85,30 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     assert(groupedBucketFiles.size <= 8)
 
     for ((bucketId, bucketFiles) <- groupedBucketFiles) {
-      for (bucketFile <- bucketFiles) {
-        val df = sqlContext.read.format(source).load(bucketFile.getAbsolutePath)
-          .select((bucketCols ++ sortCols).map(col): _*)
+      for (bucketFilePath <- bucketFiles.map(_.getAbsolutePath)) {
+        val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
+        val columns = (bucketCols ++ sortCols).zip(types).map {
+          case (colName, dt) => col(colName).cast(dt)
+        }
+        val readBack = sqlContext.read.format(source).load(bucketFilePath).select(columns: _*)
 
         if (sortCols.nonEmpty) {
-          checkAnswer(df.sort(sortCols.map(col): _*), df.collect())
+          checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
         }
 
-        val rows = df.select(bucketCols.map(col): _*).queryExecution.toRdd.map(_.copy()).collect()
+        val qe = readBack.select(bucketCols.map(col): _*).queryExecution
+        val rows = qe.toRdd.map(_.copy()).collect()
+        val getHashCode =
+          UnsafeProjection.create(new Murmur3Hash(qe.analyzed.output) :: Nil, qe.analyzed.output)
 
         for (row <- rows) {
-          assert(row.isInstanceOf[UnsafeRow])
-          val actualBucketId = (row.hashCode() % 8 + 8) % 8
+          val actualBucketId = Utils.nonNegativeMod(getHashCode(row).getInt(0), 8)
           assert(actualBucketId == bucketId)
         }
       }
     }
   }
 
-  private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
-
   test("write bucketed data") {
     for (source <- Seq("parquet", "json", "orc")) {
       withTable("bucketed_table") {

http://git-wip-us.apache.org/repos/asf/spark/blob/c2ea79f9/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
index 4276f25..5e7ee48 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
@@ -38,6 +38,10 @@ public final class Murmur3_x86_32 {
   }
 
   public int hashInt(int input) {
+    return hashInt(input, seed);
+  }
+
+  public static int hashInt(int input, int seed) {
     int k1 = mixK1(input);
     int h1 = mixH1(seed, k1);
 
@@ -51,16 +55,38 @@ public final class Murmur3_x86_32 {
   public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
     // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
     assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
+    int h1 = hashBytesByInt(base, offset, lengthInBytes, seed);
+    return fmix(h1, lengthInBytes);
+  }
+
+  public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
+    assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
+    int lengthAligned = lengthInBytes - lengthInBytes % 4;
+    int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
+    for (int i = lengthAligned; i < lengthInBytes; i++) {
+      int halfWord = Platform.getByte(base, offset + i);
+      int k1 = mixK1(halfWord);
+      h1 = mixH1(h1, k1);
+    }
+    return fmix(h1, lengthInBytes);
+  }
+
+  private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
+    assert (lengthInBytes % 4 == 0);
     int h1 = seed;
     for (int i = 0; i < lengthInBytes; i += 4) {
       int halfWord = Platform.getInt(base, offset + i);
       int k1 = mixK1(halfWord);
       h1 = mixH1(h1, k1);
     }
-    return fmix(h1, lengthInBytes);
+    return h1;
   }
 
   public int hashLong(long input) {
+    return hashLong(input, seed);
+  }
+
+  public static int hashLong(long input, int seed) {
     int low = (int) input;
     int high = (int) (input >>> 32);
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org