You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/09/09 19:57:35 UTC

spark git commit: [SPARK-10461] [SQL] make sure `input.primitive` is always variable name not code at `GenerateUnsafeProjection`

Repository: spark
Updated Branches:
  refs/heads/master c0052d8d0 -> 71da1633c


[SPARK-10461] [SQL] make sure `input.primitive` is always variable name not code at `GenerateUnsafeProjection`

When we generate unsafe code inside `createCodeForXXX`, we always assign `input.primitive` to a temp variable, in case `input.primitive` is expression code rather than a variable name.

This PR refactors the code to make sure `input.primitive` is always a variable name, and also includes some typo and style fixes.

Author: Wenchen Fan <cl...@outlook.com>

Closes #8613 from cloud-fan/minor.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/71da1633
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/71da1633
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/71da1633

Branch: refs/heads/master
Commit: 71da1633c4dcc4e748fbe3b3236af90032b695ae
Parents: c0052d8
Author: Wenchen Fan <cl...@outlook.com>
Authored: Wed Sep 9 10:57:29 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Sep 9 10:57:29 2015 -0700

----------------------------------------------------------------------
 .../expressions/codegen/CodeGenerator.scala     | 10 +--
 .../codegen/GenerateMutableProjection.scala     |  2 -
 .../codegen/GenerateProjection.scala            | 12 +--
 .../codegen/GenerateUnsafeProjection.scala      | 85 +++++++++++---------
 .../expressions/complexTypeCreator.scala        | 33 ++++----
 5 files changed, 75 insertions(+), 67 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/71da1633/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index bf96248..da3103b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -84,11 +84,11 @@ class CodeGenContext {
   /**
    * Holding all the functions those will be added into generated class.
    */
-  val addedFuntions: mutable.Map[String, String] =
+  val addedFunctions: mutable.Map[String, String] =
     mutable.Map.empty[String, String]
 
   def addNewFunction(funcName: String, funcCode: String): Unit = {
-    addedFuntions += ((funcName, funcCode))
+    addedFunctions += ((funcName, funcCode))
   }
 
   final val JAVA_BOOLEAN = "boolean"
@@ -298,8 +298,8 @@ class CodeGenContext {
            |  $body
            |}
          """.stripMargin
-         addNewFunction(name, code)
-         name
+        addNewFunction(name, code)
+        name
       }
 
       functions.map(name => s"$name($row);").mkString("\n")
@@ -337,7 +337,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
   }
 
   protected def declareAddedFunctions(ctx: CodeGenContext): String = {
-    ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
+    ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/71da1633/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index b4d4df8..793023b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions.codegen
 
-import scala.collection.mutable.ArrayBuffer
-
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 import org.apache.spark.sql.types.DecimalType

http://git-wip-us.apache.org/repos/asf/spark/blob/71da1633/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index c744e84..2164ddf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -48,7 +48,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
     val columns = expressions.zipWithIndex.map {
       case (e, i) =>
         s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n"
-    }.mkString("\n      ")
+    }.mkString("\n")
 
     val initColumns = expressions.zipWithIndex.map {
       case (e, i) =>
@@ -67,18 +67,18 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
 
     val getCases = (0 until expressions.size).map { i =>
       s"case $i: return c$i;"
-    }.mkString("\n        ")
+    }.mkString("\n")
 
     val updateCases = expressions.zipWithIndex.map { case (e, i) =>
       s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}"
-    }.mkString("\n        ")
+    }.mkString("\n")
 
     val specificAccessorFunctions = ctx.primitiveTypes.map { jt =>
       val cases = expressions.zipWithIndex.flatMap {
         case (e, i) if ctx.javaType(e.dataType) == jt =>
           Some(s"case $i: return c$i;")
         case _ => None
-      }.mkString("\n        ")
+      }.mkString("\n")
       if (cases.length > 0) {
         val getter = "get" + ctx.primitiveTypeName(jt)
         s"""
@@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
         case (e, i) if ctx.javaType(e.dataType) == jt =>
           Some(s"case $i: { c$i = value; return; }")
         case _ => None
-      }.mkString("\n        ")
+      }.mkString("\n")
       if (cases.length > 0) {
         val setter = "set" + ctx.primitiveTypeName(jt)
         s"""
@@ -152,7 +152,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
 
     val copyColumns = expressions.zipWithIndex.map { case (e, i) =>
         s"""if (!nullBits[$i]) arr[$i] = c$i;"""
-    }.mkString("\n      ")
+    }.mkString("\n")
 
     val code = s"""
     public SpecificProjection generate($exprType[] expr) {

http://git-wip-us.apache.org/repos/asf/spark/blob/71da1633/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index b570fe8..03c5f44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -134,7 +134,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];")
     val cursor = ctx.freshName("cursor")
     ctx.addMutableState("int", cursor, s"this.$cursor = 0;")
-    val tmp = ctx.freshName("tmpBuffer")
+    val tmpBuffer = ctx.freshName("tmpBuffer")
 
     val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) =>
       val ev = createConvertCode(ctx, input, dt)
@@ -144,10 +144,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
           int $numBytes = $cursor + (${genAdditionalSize(dt, ev)});
           if ($buffer.length < $numBytes) {
             // This will not happen frequently, because the buffer is re-used.
-            byte[] $tmp = new byte[$numBytes * 2];
+            byte[] $tmpBuffer = new byte[$numBytes * 2];
             Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET,
-              $tmp, Platform.BYTE_ARRAY_OFFSET, $buffer.length);
-            $buffer = $tmp;
+              $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length);
+            $buffer = $tmpBuffer;
           }
           $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $numBytes);
          """
@@ -207,20 +207,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     val buffer = ctx.freshName("buffer")
     ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
     val outputIsNull = ctx.freshName("isNull")
-    val tmp = ctx.freshName("tmp")
     val numElements = ctx.freshName("numElements")
     val fixedSize = ctx.freshName("fixedSize")
     val numBytes = ctx.freshName("numBytes")
     val elements = ctx.freshName("elements")
     val cursor = ctx.freshName("cursor")
     val index = ctx.freshName("index")
+    val elementName = ctx.freshName("elementName")
 
-    val element = GeneratedExpressionCode(
-      code = "",
-      isNull = s"$tmp.isNullAt($index)",
-      primitive = s"${ctx.getValue(tmp, elementType, index)}"
-    )
-    val convertedElement: GeneratedExpressionCode = createConvertCode(ctx, element, elementType)
+    val element = {
+      val code = s"${ctx.javaType(elementType)} $elementName = " +
+        s"${ctx.getValue(input.primitive, elementType, index)};"
+      val isNull = s"${input.primitive}.isNullAt($index)"
+      GeneratedExpressionCode(code, isNull, elementName)
+    }
+
+    val convertedElement = createConvertCode(ctx, element, elementType)
 
     // go through the input array to calculate how many bytes we need.
     val calculateNumBytes = elementType match {
@@ -272,6 +274,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
         // Should we do word align?
         val elementSize = elementType.defaultSize
         s"""
+          ${convertedElement.code}
           Platform.put${ctx.primitiveTypeName(elementType)}(
             $buffer,
             Platform.BYTE_ARRAY_OFFSET + $cursor,
@@ -280,6 +283,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
         """
       case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
         s"""
+          ${convertedElement.code}
           Platform.putLong(
             $buffer,
             Platform.BYTE_ARRAY_OFFSET + $cursor,
@@ -307,11 +311,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       ${input.code}
       final boolean $outputIsNull = ${input.isNull};
       if (!$outputIsNull) {
-        final ArrayData $tmp = ${input.primitive};
-        if ($tmp instanceof UnsafeArrayData) {
-          $output = (UnsafeArrayData) $tmp;
+        if (${input.primitive} instanceof UnsafeArrayData) {
+          $output = (UnsafeArrayData) ${input.primitive};
         } else {
-          final int $numElements = $tmp.numElements();
+          final int $numElements = ${input.primitive}.numElements();
           final int $fixedSize = 4 * $numElements;
           int $numBytes = $fixedSize;
 
@@ -350,29 +353,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       valueType: DataType): GeneratedExpressionCode = {
     val output = ctx.freshName("convertedMap")
     val outputIsNull = ctx.freshName("isNull")
-    val tmp = ctx.freshName("tmp")
-
-    val keyArray = GeneratedExpressionCode(
-      code = "",
-      isNull = "false",
-      primitive = s"$tmp.keyArray()"
-    )
-    val valueArray = GeneratedExpressionCode(
-      code = "",
-      isNull = "false",
-      primitive = s"$tmp.valueArray()"
-    )
-    val convertedKeys: GeneratedExpressionCode = createCodeForArray(ctx, keyArray, keyType)
-    val convertedValues: GeneratedExpressionCode = createCodeForArray(ctx, valueArray, valueType)
+    val keyArrayName = ctx.freshName("keyArrayName")
+    val valueArrayName = ctx.freshName("valueArrayName")
+
+    val keyArray = {
+      val code = s"ArrayData $keyArrayName = ${input.primitive}.keyArray();"
+      val isNull = "false"
+      GeneratedExpressionCode(code, isNull, keyArrayName)
+    }
+
+    val valueArray = {
+      val code = s"ArrayData $valueArrayName = ${input.primitive}.valueArray();"
+      val isNull = "false"
+      GeneratedExpressionCode(code, isNull, valueArrayName)
+    }
+
+    val convertedKeys = createCodeForArray(ctx, keyArray, keyType)
+    val convertedValues = createCodeForArray(ctx, valueArray, valueType)
 
     val code = s"""
       ${input.code}
       final boolean $outputIsNull = ${input.isNull};
       UnsafeMapData $output = null;
       if (!$outputIsNull) {
-        final MapData $tmp = ${input.primitive};
-        if ($tmp instanceof UnsafeMapData) {
-          $output = (UnsafeMapData) $tmp;
+        if (${input.primitive} instanceof UnsafeMapData) {
+          $output = (UnsafeMapData) ${input.primitive};
         } else {
           ${convertedKeys.code}
           ${convertedValues.code}
@@ -393,22 +398,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     case t: StructType =>
       val output = ctx.freshName("convertedStruct")
       val outputIsNull = ctx.freshName("isNull")
-      val tmp = ctx.freshName("tmp")
       val fieldTypes = t.fields.map(_.dataType)
       val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
-        val getFieldCode = ctx.getValue(tmp, dt, i.toString)
-        val fieldIsNull = s"$tmp.isNullAt($i)"
-        GeneratedExpressionCode("", fieldIsNull, getFieldCode)
+        val fieldName = ctx.freshName("fieldName")
+        val code = s"${ctx.javaType(dt)} $fieldName = " +
+          s"${ctx.getValue(input.primitive, dt, i.toString)};"
+        val isNull = s"${input.primitive}.isNullAt($i)"
+        GeneratedExpressionCode(code, isNull, fieldName)
       }
-      val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes)
+      val converter = createCodeForStruct(ctx, input.primitive, fieldEvals, fieldTypes)
       val code = s"""
         ${input.code}
          UnsafeRow $output = null;
          final boolean $outputIsNull = ${input.isNull};
          if (!$outputIsNull) {
-           final InternalRow $tmp = ${input.primitive};
-           if ($tmp instanceof UnsafeRow) {
-             $output = (UnsafeRow) $tmp;
+           if (${input.primitive} instanceof UnsafeRow) {
+             $output = (UnsafeRow) ${input.primitive};
            } else {
              ${converter.code}
              $output = ${converter.primitive};

http://git-wip-us.apache.org/repos/asf/spark/blob/71da1633/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 1c54671..82eab5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -48,21 +48,22 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val arrayClass = classOf[GenericArrayData].getName
+    val values = ctx.freshName("values")
     s"""
       final boolean ${ev.isNull} = false;
-      final Object[] values = new Object[${children.size}];
+      final Object[] $values = new Object[${children.size}];
     """ +
       children.zipWithIndex.map { case (e, i) =>
         val eval = e.gen(ctx)
         eval.code + s"""
           if (${eval.isNull}) {
-            values[$i] = null;
+            $values[$i] = null;
           } else {
-            values[$i] = ${eval.primitive};
+            $values[$i] = ${eval.primitive};
           }
          """
       }.mkString("\n") +
-      s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);"
+      s"final ArrayData ${ev.primitive} = new $arrayClass($values);"
   }
 
   override def prettyName: String = "array"
@@ -94,21 +95,23 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val rowClass = classOf[GenericMutableRow].getName
+    val rowClass = classOf[GenericInternalRow].getName
+    val values = ctx.freshName("values")
     s"""
       boolean ${ev.isNull} = false;
-      final $rowClass ${ev.primitive} = new $rowClass(${children.size});
+      final Object[] $values = new Object[${children.size}];
     """ +
       children.zipWithIndex.map { case (e, i) =>
         val eval = e.gen(ctx)
         eval.code + s"""
           if (${eval.isNull}) {
-            ${ev.primitive}.update($i, null);
+            $values[$i] = null;
           } else {
-            ${ev.primitive}.update($i, ${eval.primitive});
+            $values[$i] = ${eval.primitive};
           }
          """
-      }.mkString("\n")
+      }.mkString("\n") +
+      s"final InternalRow ${ev.primitive} = new $rowClass($values);"
   }
 
   override def prettyName: String = "struct"
@@ -161,21 +164,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val rowClass = classOf[GenericMutableRow].getName
+    val rowClass = classOf[GenericInternalRow].getName
+    val values = ctx.freshName("values")
     s"""
       boolean ${ev.isNull} = false;
-      final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size});
+      final Object[] $values = new Object[${valExprs.size}];
     """ +
       valExprs.zipWithIndex.map { case (e, i) =>
         val eval = e.gen(ctx)
         eval.code + s"""
           if (${eval.isNull}) {
-            ${ev.primitive}.update($i, null);
+            $values[$i] = null;
           } else {
-            ${ev.primitive}.update($i, ${eval.primitive});
+            $values[$i] = ${eval.primitive};
           }
          """
-      }.mkString("\n")
+      }.mkString("\n") +
+      s"final InternalRow ${ev.primitive} = new $rowClass($values);"
   }
 
   override def prettyName: String = "named_struct"


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org