You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/06/26 07:17:20 UTC

spark git commit: [SPARK-8620] [SQL] cleanup CodeGenContext

Repository: spark
Updated Branches:
  refs/heads/master 47c874bab -> 40360112c


[SPARK-8620] [SQL] cleanup CodeGenContext

Fix docs, remove `nativeTypes`, and use the Java type to derive the boxed type, default value, etc., to avoid handling `DateType` and `TimestampType` as int and long again and again.

Author: Wenchen Fan <cl...@outlook.com>

Closes #7010 from cloud-fan/cg and squashes the following commits:

aa01cf9 [Wenchen Fan] cleanup CodeGenContext


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/40360112
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/40360112
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/40360112

Branch: refs/heads/master
Commit: 40360112c417b5432564f4bcb8a9100f4066b55e
Parents: 47c874b
Author: Wenchen Fan <cl...@outlook.com>
Authored: Thu Jun 25 22:16:53 2015 -0700
Committer: Davies Liu <da...@databricks.com>
Committed: Thu Jun 25 22:16:53 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   |   5 +-
 .../expressions/codegen/CodeGenerator.scala     | 130 +++++++++----------
 .../codegen/GenerateProjection.scala            |  34 ++---
 .../catalyst/expressions/stringOperations.scala |   1 -
 4 files changed, 82 insertions(+), 88 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/40360112/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8bd7fc1..8d66968 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -467,11 +467,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
         defineCodeGen(ctx, ev, c => s"!$c.isZero()")
       case (dt: NumericType, BooleanType) =>
         defineCodeGen(ctx, ev, c => s"$c != 0")
-
-      case (_: DecimalType, IntegerType) =>
-        defineCodeGen(ctx, ev, c => s"($c).toInt()")
       case (_: DecimalType, dt: NumericType) =>
-        defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
+        defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()")
       case (_: NumericType, dt: NumericType) =>
         defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/40360112/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 47c5455..e20e3a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -59,6 +59,14 @@ class CodeGenContext {
   val stringType: String = classOf[UTF8String].getName
   val decimalType: String = classOf[Decimal].getName
 
+  final val JAVA_BOOLEAN = "boolean"
+  final val JAVA_BYTE = "byte"
+  final val JAVA_SHORT = "short"
+  final val JAVA_INT = "int"
+  final val JAVA_LONG = "long"
+  final val JAVA_FLOAT = "float"
+  final val JAVA_DOUBLE = "double"
+
   private val curId = new java.util.concurrent.atomic.AtomicInteger()
 
   /**
@@ -72,98 +80,94 @@ class CodeGenContext {
   }
 
   /**
-   * Return the code to access a column for given DataType
+   * Returns the code to access a column in Row for a given DataType.
    */
   def getColumn(dataType: DataType, ordinal: Int): String = {
-    if (isNativeType(dataType)) {
-      s"i.${accessorForType(dataType)}($ordinal)"
+    val jt = javaType(dataType)
+    if (isPrimitiveType(jt)) {
+      s"i.get${primitiveTypeName(jt)}($ordinal)"
     } else {
-      s"(${boxedType(dataType)})i.apply($ordinal)"
+      s"($jt)i.apply($ordinal)"
     }
   }
 
   /**
-   * Return the code to update a column in Row for given DataType
+   * Returns the code to update a column in Row for a given DataType.
    */
   def setColumn(dataType: DataType, ordinal: Int, value: String): String = {
-    if (isNativeType(dataType)) {
-      s"${mutatorForType(dataType)}($ordinal, $value)"
+    val jt = javaType(dataType)
+    if (isPrimitiveType(jt)) {
+      s"set${primitiveTypeName(jt)}($ordinal, $value)"
     } else {
       s"update($ordinal, $value)"
     }
   }
 
   /**
-   * Return the name of accessor in Row for a DataType
+   * Returns the name used in accessor and setter for a Java primitive type.
    */
-  def accessorForType(dt: DataType): String = dt match {
-    case IntegerType => "getInt"
-    case other => s"get${boxedType(dt)}"
+  def primitiveTypeName(jt: String): String = jt match {
+    case JAVA_INT => "Int"
+    case _ => boxedType(jt)
   }
 
-  /**
-   * Return the name of mutator in Row for a DataType
-   */
-  def mutatorForType(dt: DataType): String = dt match {
-    case IntegerType => "setInt"
-    case other => s"set${boxedType(dt)}"
-  }
+  def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt))
 
   /**
-   * Return the Java type for a DataType
+   * Returns the Java type for a DataType.
    */
   def javaType(dt: DataType): String = dt match {
-    case IntegerType => "int"
-    case LongType => "long"
-    case ShortType => "short"
-    case ByteType => "byte"
-    case DoubleType => "double"
-    case FloatType => "float"
-    case BooleanType => "boolean"
+    case BooleanType => JAVA_BOOLEAN
+    case ByteType => JAVA_BYTE
+    case ShortType => JAVA_SHORT
+    case IntegerType => JAVA_INT
+    case LongType => JAVA_LONG
+    case FloatType => JAVA_FLOAT
+    case DoubleType => JAVA_DOUBLE
     case dt: DecimalType => decimalType
     case BinaryType => "byte[]"
     case StringType => stringType
-    case DateType => "int"
-    case TimestampType => "long"
+    case DateType => JAVA_INT
+    case TimestampType => JAVA_LONG
     case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
     case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
     case _ => "Object"
   }
 
   /**
-   * Return the boxed type in Java
+   * Returns the boxed type in Java.
    */
-  def boxedType(dt: DataType): String = dt match {
-    case IntegerType => "Integer"
-    case LongType => "Long"
-    case ShortType => "Short"
-    case ByteType => "Byte"
-    case DoubleType => "Double"
-    case FloatType => "Float"
-    case BooleanType => "Boolean"
-    case DateType => "Integer"
-    case TimestampType => "Long"
-    case _ => javaType(dt)
+  def boxedType(jt: String): String = jt match {
+    case JAVA_BOOLEAN => "Boolean"
+    case JAVA_BYTE => "Byte"
+    case JAVA_SHORT => "Short"
+    case JAVA_INT => "Integer"
+    case JAVA_LONG => "Long"
+    case JAVA_FLOAT => "Float"
+    case JAVA_DOUBLE => "Double"
+    case other => other
   }
 
+  def boxedType(dt: DataType): String = boxedType(javaType(dt))
+
   /**
-   * Return the representation of default value for given DataType
+   * Returns the representation of default value for a given Java Type.
    */
-  def defaultValue(dt: DataType): String = dt match {
-    case BooleanType => "false"
-    case FloatType => "-1.0f"
-    case ShortType => "(short)-1"
-    case LongType => "-1L"
-    case ByteType => "(byte)-1"
-    case DoubleType => "-1.0"
-    case IntegerType => "-1"
-    case DateType => "-1"
-    case TimestampType => "-1L"
+  def defaultValue(jt: String): String = jt match {
+    case JAVA_BOOLEAN => "false"
+    case JAVA_BYTE => "(byte)-1"
+    case JAVA_SHORT => "(short)-1"
+    case JAVA_INT => "-1"
+    case JAVA_LONG => "-1L"
+    case JAVA_FLOAT => "-1.0f"
+    case JAVA_DOUBLE => "-1.0"
     case _ => "null"
   }
 
+  def defaultValue(dt: DataType): String = defaultValue(javaType(dt))
+
   /**
-   * Generate code for equal expression in Java
+   * Generates code for equal expression in Java.
    */
   def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
     case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
@@ -172,7 +176,7 @@ class CodeGenContext {
   }
 
   /**
-   * Generate code for compare expression in Java
+   * Generates code for compare expression in Java.
    */
   def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
     // java boolean doesn't support > or < operator
@@ -184,25 +188,17 @@ class CodeGenContext {
   }
 
   /**
-   * List of data types that have special accessors and setters in [[InternalRow]].
+   * List of java data types that have special accessors and setters in [[InternalRow]].
    */
-  val nativeTypes =
-    Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType)
+  val primitiveTypes =
+    Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE)
 
   /**
-   * Returns true if the data type has a special accessor and setter in [[InternalRow]].
+   * Returns true if the Java type has a special accessor and setter in [[InternalRow]].
    */
-  def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt)
+  def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)
 
-  /**
-   * List of data types who's Java type is primitive type
-   */
-  val primitiveTypes = nativeTypes ++ Seq(DateType, TimestampType)
-
-  /**
-   * Returns true if the Java type is primitive type
-   */
-  def isPrimitiveType(dt: DataType): Boolean = primitiveTypes.contains(dt)
+  def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/40360112/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index e362625..624e1cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -72,54 +72,56 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
       s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}"
     }.mkString("\n        ")
 
-    val specificAccessorFunctions = ctx.nativeTypes.map { dataType =>
+    val specificAccessorFunctions = ctx.primitiveTypes.map { jt =>
       val cases = expressions.zipWithIndex.flatMap {
-        case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) =>
-          List(s"case $i: return c$i;")
-        case _ => Nil
+        case (e, i) if ctx.javaType(e.dataType) == jt =>
+          Some(s"case $i: return c$i;")
+        case _ => None
       }.mkString("\n        ")
       if (cases.length > 0) {
+        val getter = "get" + ctx.primitiveTypeName(jt)
         s"""
       @Override
-      public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
+      public $jt $getter(int i) {
         if (isNullAt(i)) {
-          return ${ctx.defaultValue(dataType)};
+          return ${ctx.defaultValue(jt)};
         }
         switch (i) {
         $cases
         }
         throw new IllegalArgumentException("Invalid index: " + i
-          + " in ${ctx.accessorForType(dataType)}");
+          + " in $getter");
       }"""
       } else {
         ""
       }
-    }.mkString("\n")
+    }.filter(_.length > 0).mkString("\n")
 
-    val specificMutatorFunctions = ctx.nativeTypes.map { dataType =>
+    val specificMutatorFunctions = ctx.primitiveTypes.map { jt =>
       val cases = expressions.zipWithIndex.flatMap {
-        case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) =>
-          List(s"case $i: { c$i = value; return; }")
-        case _ => Nil
+        case (e, i) if ctx.javaType(e.dataType) == jt =>
+          Some(s"case $i: { c$i = value; return; }")
+        case _ => None
       }.mkString("\n        ")
       if (cases.length > 0) {
+        val setter = "set" + ctx.primitiveTypeName(jt)
         s"""
       @Override
-      public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) {
+      public void $setter(int i, $jt value) {
         nullBits[i] = false;
         switch (i) {
         $cases
         }
         throw new IllegalArgumentException("Invalid index: " + i +
-          " in ${ctx.mutatorForType(dataType)}");
+          " in $setter}");
       }"""
       } else {
         ""
       }
-    }.mkString("\n")
+    }.filter(_.length > 0).mkString("\n")
 
     val hashValues = expressions.zipWithIndex.map { case (e, i) =>
-      val col = newTermName(s"c$i")
+      val col = s"c$i"
       val nonNull = e.dataType match {
         case BooleanType => s"$col ? 0 : 1"
         case ByteType | ShortType | IntegerType | DateType => s"$col"

http://git-wip-us.apache.org/repos/asf/spark/blob/40360112/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 44416e7..a6225fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
 import java.util.regex.Pattern
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.expressions.Substring
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org