You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/01/29 19:24:29 UTC

spark git commit: [SPARK-13072] [SQL] simplify and improve murmur3 hash expression codegen

Repository: spark
Updated Branches:
  refs/heads/master e4c1162b6 -> c5f745ede


[SPARK-13072] [SQL] simplify and improve murmur3 hash expression codegen

Simplify the generated code of the murmur3 hash expression (remove several unnecessary local variables), and skip the null check when the input is known to be non-nullable.

Generated code comparison for `hash(int, double, string, array<int>)` with seed 42:
**before:**
```
  public UnsafeRow apply(InternalRow i) {
    /* hash(input[0, int],input[1, double],input[2, string],input[3, array<int>],42) */
    int value1 = 42;
    /* input[0, int] */
    int value3 = i.getInt(0);
    if (!false) {
      value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1);
    }
    /* input[1, double] */
    double value5 = i.getDouble(1);
    if (!false) {
      value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1);
    }
    /* input[2, string] */
    boolean isNull6 = i.isNullAt(2);
    UTF8String value7 = isNull6 ? null : (i.getUTF8String(2));
    if (!isNull6) {
      value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1);
    }
    /* input[3, array<int>] */
    boolean isNull8 = i.isNullAt(3);
    ArrayData value9 = isNull8 ? null : (i.getArray(3));
    if (!isNull8) {
      int result10 = value1;
      for (int index11 = 0; index11 < value9.numElements(); index11++) {
        if (!value9.isNullAt(index11)) {
          final int element12 = value9.getInt(index11);
          result10 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element12, result10);
        }
      }
      value1 = result10;
    }
  }
```
**after:**
```
  public UnsafeRow apply(InternalRow i) {
    /* hash(input[0, int],input[1, double],input[2, string],input[3, array<int>],42) */
    int value1 = 42;
    /* input[0, int] */
    int value3 = i.getInt(0);
    value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1);
    /* input[1, double] */
    double value5 = i.getDouble(1);
    value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1);
    /* input[2, string] */
    boolean isNull6 = i.isNullAt(2);
    UTF8String value7 = isNull6 ? null : (i.getUTF8String(2));

    if (!isNull6) {
      value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1);
    }

    /* input[3, array<int>] */
    boolean isNull8 = i.isNullAt(3);
    ArrayData value9 = isNull8 ? null : (i.getArray(3));
    if (!isNull8) {
      for (int index10 = 0; index10 < value9.numElements(); index10++) {
        final int element11 = value9.getInt(index10);
        value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element11, value1);
      }
    }

    rowWriter14.write(0, value1);
    return result12;
  }
```

Author: Wenchen Fan <we...@databricks.com>

Closes #10974 from cloud-fan/codegen.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c5f745ed
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c5f745ed
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c5f745ed

Branch: refs/heads/master
Commit: c5f745ede01831b59c57effa7de88c648b82c13d
Parents: e4c1162
Author: Wenchen Fan <we...@databricks.com>
Authored: Fri Jan 29 10:24:23 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Fri Jan 29 10:24:23 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/misc.scala   | 155 +++++++++----------
 1 file changed, 69 insertions(+), 86 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c5f745ed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 493e0aa..8480c3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -325,36 +325,62 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
 
   override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
     ev.isNull = "false"
-    val childrenHash = children.zipWithIndex.map {
-      case (child, dt) =>
-        val childGen = child.gen(ctx)
-        val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx)
-        s"""
-          ${childGen.code}
-          if (!${childGen.isNull}) {
-            ${childHash.code}
-            ${ev.value} = ${childHash.value};
-          }
-        """
+    val childrenHash = children.map { child =>
+      val childGen = child.gen(ctx)
+      childGen.code + generateNullCheck(child.nullable, childGen.isNull) {
+        computeHash(childGen.value, child.dataType, ev.value, ctx)
+      }
     }.mkString("\n")
+
     s"""
       int ${ev.value} = $seed;
       $childrenHash
     """
   }
 
+  private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = {
+    if (nullable) {
+      s"""
+        if (!$isNull) {
+          $execution
+        }
+      """
+    } else {
+      "\n" + execution
+    }
+  }
+
+  private def nullSafeElementHash(
+      input: String,
+      index: String,
+      nullable: Boolean,
+      elementType: DataType,
+      result: String,
+      ctx: CodegenContext): String = {
+    val element = ctx.freshName("element")
+
+    generateNullCheck(nullable, s"$input.isNullAt($index)") {
+      s"""
+        final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
+        ${computeHash(element, elementType, result, ctx)}
+      """
+    }
+  }
+
   private def computeHash(
       input: String,
       dataType: DataType,
-      seed: String,
-      ctx: CodegenContext): ExprCode = {
+      result: String,
+      ctx: CodegenContext): String = {
     val hasher = classOf[Murmur3_x86_32].getName
-    def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)")
-    def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)")
-    def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v)
+
+    def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);"
+    def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);"
+    def hashBytes(b: String): String =
+      s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);"
 
     dataType match {
-      case NullType => inlineValue(seed)
+      case NullType => ""
       case BooleanType => hashInt(s"$input ? 1 : 0")
       case ByteType | ShortType | IntegerType | DateType => hashInt(input)
       case LongType | TimestampType => hashLong(input)
@@ -365,91 +391,48 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
           hashLong(s"$input.toUnscaledLong()")
         } else {
           val bytes = ctx.freshName("bytes")
-          val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();"
-          val offset = "Platform.BYTE_ARRAY_OFFSET"
-          val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)"
-          ExprCode(code, "false", result)
+          s"""
+            final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
+            ${hashBytes(bytes)}
+          """
         }
       case CalendarIntervalType =>
-        val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)"
-        val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)"
-        inlineValue(monthsHash)
-      case BinaryType =>
-        val offset = "Platform.BYTE_ARRAY_OFFSET"
-        inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)")
+        val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)"
+        s"$result = $hasher.hashInt($input.months, $microsecondsHash);"
+      case BinaryType => hashBytes(input)
       case StringType =>
         val baseObject = s"$input.getBaseObject()"
         val baseOffset = s"$input.getBaseOffset()"
         val numBytes = s"$input.numBytes()"
-        inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)")
+        s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);"
 
-      case ArrayType(et, _) =>
-        val result = ctx.freshName("result")
+      case ArrayType(et, containsNull) =>
         val index = ctx.freshName("index")
-        val element = ctx.freshName("element")
-        val elementHash = computeHash(element, et, result, ctx)
-        val code =
-          s"""
-            int $result = $seed;
-            for (int $index = 0; $index < $input.numElements(); $index++) {
-              if (!$input.isNullAt($index)) {
-                final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)};
-                ${elementHash.code}
-                $result = ${elementHash.value};
-              }
-            }
-          """
-        ExprCode(code, "false", result)
+        s"""
+          for (int $index = 0; $index < $input.numElements(); $index++) {
+            ${nullSafeElementHash(input, index, containsNull, et, result, ctx)}
+          }
+        """
 
-      case MapType(kt, vt, _) =>
-        val result = ctx.freshName("result")
+      case MapType(kt, vt, valueContainsNull) =>
         val index = ctx.freshName("index")
         val keys = ctx.freshName("keys")
         val values = ctx.freshName("values")
-        val key = ctx.freshName("key")
-        val value = ctx.freshName("value")
-        val keyHash = computeHash(key, kt, result, ctx)
-        val valueHash = computeHash(value, vt, result, ctx)
-        val code =
-          s"""
-            int $result = $seed;
-            final ArrayData $keys = $input.keyArray();
-            final ArrayData $values = $input.valueArray();
-            for (int $index = 0; $index < $input.numElements(); $index++) {
-              final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)};
-              ${keyHash.code}
-              $result = ${keyHash.value};
-              if (!$values.isNullAt($index)) {
-                final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)};
-                ${valueHash.code}
-                $result = ${valueHash.value};
-              }
-            }
-          """
-        ExprCode(code, "false", result)
+        s"""
+          final ArrayData $keys = $input.keyArray();
+          final ArrayData $values = $input.valueArray();
+          for (int $index = 0; $index < $input.numElements(); $index++) {
+            ${nullSafeElementHash(keys, index, false, kt, result, ctx)}
+            ${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)}
+          }
+        """
 
       case StructType(fields) =>
-        val result = ctx.freshName("result")
-        val fieldsHash = fields.map(_.dataType).zipWithIndex.map {
-          case (dt, index) =>
-            val field = ctx.freshName("field")
-            val fieldHash = computeHash(field, dt, result, ctx)
-            s"""
-              if (!$input.isNullAt($index)) {
-                final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)};
-                ${fieldHash.code}
-                $result = ${fieldHash.value};
-              }
-            """
+        fields.zipWithIndex.map { case (field, index) =>
+          nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
         }.mkString("\n")
-        val code =
-          s"""
-            int $result = $seed;
-            $fieldsHash
-          """
-        ExprCode(code, "false", result)
 
-      case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx)
+      case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx)
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org