You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/17 10:27:19 UTC

spark git commit: [SPARK-9022] [SQL] Generated projections for UnsafeRow

Repository: spark
Updated Branches:
  refs/heads/master 5a3c1ad08 -> ec8973d12


[SPARK-9022] [SQL] Generated projections for UnsafeRow

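Added two projections, UnsafeProjection (whose implementation is code-generated by GenerateUnsafeProjection) and FromUnsafeProjection. UnsafeProjection converts an InternalRow into an UnsafeRow, and FromUnsafeProjection converts an UnsafeRow back into a GenericInternalRow.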

Both projections reuse their output buffer across calls, similar to MutableProjection, but without MutableProjection's full interface.

cc rxin JoshRosen

Author: Davies Liu <da...@databricks.com>

Closes #7437 from davies/unsafe_proj2 and squashes the following commits:

dbf538e [Davies Liu] test with all the expression (only for supported types)
dc737b2 [Davies Liu] address comment
e424520 [Davies Liu] fix scala style
70e231c [Davies Liu] address comments
729138d [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_proj2
5a26373 [Davies Liu] unsafe projections


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ec8973d1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ec8973d1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ec8973d1

Branch: refs/heads/master
Commit: ec8973d1245d4a99edeb7365d7f4b0063ac31ddf
Parents: 5a3c1ad
Author: Davies Liu <da...@databricks.com>
Authored: Fri Jul 17 01:27:14 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Fri Jul 17 01:27:14 2015 -0700

----------------------------------------------------------------------
 .../sql/execution/UnsafeExternalRowSorter.java  |  27 ++--
 .../spark/sql/catalyst/expressions/Cast.scala   |   8 +-
 .../sql/catalyst/expressions/Projection.scala   |  35 ++++++
 .../expressions/UnsafeRowConverter.scala        |  69 +++++-----
 .../expressions/codegen/CodeGenerator.scala     |  15 ++-
 .../codegen/GenerateProjection.scala            |   4 +-
 .../codegen/GenerateUnsafeProjection.scala      | 125 +++++++++++++++++++
 .../catalyst/expressions/decimalFunctions.scala |   2 +-
 .../spark/sql/catalyst/expressions/math.scala   |   2 +-
 .../spark/sql/catalyst/expressions/misc.scala   |  17 ++-
 .../expressions/ExpressionEvalHelper.scala      |  34 ++++-
 11 files changed, 266 insertions(+), 72 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ec8973d1/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index b94601c..d1d81c8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -28,13 +28,11 @@ import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
 import org.apache.spark.sql.AbstractScalaRowIterator;
 import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.expressions.ObjectUnsafeColumnWriter;
 import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter;
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter;
 import org.apache.spark.sql.catalyst.util.ObjectPool;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
 import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
@@ -52,10 +50,9 @@ final class UnsafeExternalRowSorter {
   private long numRowsInserted = 0;
 
   private final StructType schema;
-  private final UnsafeRowConverter rowConverter;
+  private final UnsafeProjection unsafeProjection;
   private final PrefixComputer prefixComputer;
   private final UnsafeExternalSorter sorter;
-  private byte[] rowConversionBuffer = new byte[1024 * 8];
 
   public static abstract class PrefixComputer {
     abstract long computePrefix(InternalRow row);
@@ -67,7 +64,7 @@ final class UnsafeExternalRowSorter {
       PrefixComparator prefixComparator,
       PrefixComputer prefixComputer) throws IOException {
     this.schema = schema;
-    this.rowConverter = new UnsafeRowConverter(schema);
+    this.unsafeProjection = UnsafeProjection.create(schema);
     this.prefixComputer = prefixComputer;
     final SparkEnv sparkEnv = SparkEnv.get();
     final TaskContext taskContext = TaskContext.get();
@@ -94,18 +91,12 @@ final class UnsafeExternalRowSorter {
 
   @VisibleForTesting
   void insertRow(InternalRow row) throws IOException {
-    final int sizeRequirement = rowConverter.getSizeRequirement(row);
-    if (sizeRequirement > rowConversionBuffer.length) {
-      rowConversionBuffer = new byte[sizeRequirement];
-    }
-    final int bytesWritten = rowConverter.writeRow(
-      row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, sizeRequirement, null);
-    assert (bytesWritten == sizeRequirement);
+    UnsafeRow unsafeRow = unsafeProjection.apply(row);
     final long prefix = prefixComputer.computePrefix(row);
     sorter.insertRecord(
-      rowConversionBuffer,
-      PlatformDependent.BYTE_ARRAY_OFFSET,
-      sizeRequirement,
+      unsafeRow.getBaseObject(),
+      unsafeRow.getBaseOffset(),
+      unsafeRow.getSizeInBytes(),
       prefix
     );
     numRowsInserted++;
@@ -186,7 +177,7 @@ final class UnsafeExternalRowSorter {
   public static boolean supportsSchema(StructType schema) {
     // TODO: add spilling note to explain why we do this for now:
     for (StructField field : schema.fields()) {
-      if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) {
+      if (!UnsafeColumnWriter.canEmbed(field.dataType())) {
         return false;
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/ec8973d1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 65ae87f..692b9fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -424,20 +424,20 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
 
       case (BinaryType, StringType) =>
         defineCodeGen (ctx, ev, c =>
-          s"${ctx.stringType}.fromBytes($c)")
+          s"UTF8String.fromBytes($c)")
 
       case (DateType, StringType) =>
         defineCodeGen(ctx, ev, c =>
-          s"""${ctx.stringType}.fromString(
+          s"""UTF8String.fromString(
                 org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""")
 
       case (TimestampType, StringType) =>
         defineCodeGen(ctx, ev, c =>
-          s"""${ctx.stringType}.fromString(
+          s"""UTF8String.fromString(
                 org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""")
 
       case (_, StringType) =>
-        defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))")
+        defineCodeGen(ctx, ev, c => s"UTF8String.fromString(String.valueOf($c))")
 
       case (StringType, IntervalType) =>
         defineCodeGen(ctx, ev, c =>

http://git-wip-us.apache.org/repos/asf/spark/blob/ec8973d1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index bf47a6c..24b01ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection}
+import org.apache.spark.sql.types.{StructType, DataType}
 
 /**
  * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
@@ -74,6 +76,39 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
 }
 
 /**
+ * A projection that returns UnsafeRow.
+ */
+abstract class UnsafeProjection extends Projection {
+  override def apply(row: InternalRow): UnsafeRow
+}
+
+object UnsafeProjection {
+  def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType))
+
+  def create(fields: Seq[DataType]): UnsafeProjection = {
+    val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))
+    GenerateUnsafeProjection.generate(exprs)
+  }
+}
+
+/**
+ * A projection that could turn UnsafeRow into GenericInternalRow
+ */
+case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection {
+
+  private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) =>
+    new BoundReference(idx, dt, true)
+  }
+
+  @transient private[this] lazy val generatedProj =
+    GenerateMutableProjection.generate(expressions)()
+
+  override def apply(input: InternalRow): InternalRow = {
+    generatedProj(input)
+  }
+}
+
+/**
  * A mutable wrapper that makes two rows appear as a single concatenated row.  Designed to
  * be instantiated once per thread and reused.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/ec8973d1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index 6af5e62..885ab09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -147,77 +147,73 @@ private object UnsafeColumnWriter {
       case t => ObjectUnsafeColumnWriter
     }
   }
+
+  /**
+   * Returns whether the dataType can be embedded into UnsafeRow (not using ObjectPool).
+   */
+  def canEmbed(dataType: DataType): Boolean = {
+    forType(dataType) != ObjectUnsafeColumnWriter
+  }
 }
 
 // ------------------------------------------------------------------------------------------------
 
-private object NullUnsafeColumnWriter extends NullUnsafeColumnWriter
-private object BooleanUnsafeColumnWriter extends BooleanUnsafeColumnWriter
-private object ByteUnsafeColumnWriter extends ByteUnsafeColumnWriter
-private object ShortUnsafeColumnWriter extends ShortUnsafeColumnWriter
-private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
-private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
-private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
-private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
-private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
-private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter
-private object ObjectUnsafeColumnWriter extends ObjectUnsafeColumnWriter
 
 private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
   // Primitives don't write to the variable-length region:
   def getSize(sourceRow: InternalRow, column: Int): Int = 0
 }
 
-private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object NullUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
   override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setNullAt(column)
     0
   }
 }
 
-private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object BooleanUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
   override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setBoolean(column, source.getBoolean(column))
     0
   }
 }
 
-private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object ByteUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
   override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setByte(column, source.getByte(column))
     0
   }
 }
 
-private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object ShortUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
   override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setShort(column, source.getShort(column))
     0
   }
 }
 
-private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object IntUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
   override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setInt(column, source.getInt(column))
     0
   }
 }
 
-private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object LongUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
   override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setLong(column, source.getLong(column))
     0
   }
 }
 
-private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object FloatUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
   override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setFloat(column, source.getFloat(column))
     0
   }
 }
 
-private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
+private object DoubleUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter {
   override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     target.setDouble(column, source.getDouble(column))
     0
@@ -226,18 +222,21 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr
 
 private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
 
-  def getBytes(source: InternalRow, column: Int): Array[Byte]
+  protected[this] def isString: Boolean
+  protected[this] def getBytes(source: InternalRow, column: Int): Array[Byte]
 
-  def getSize(source: InternalRow, column: Int): Int = {
+  override def getSize(source: InternalRow, column: Int): Int = {
     val numBytes = getBytes(source, column).length
     ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
   }
 
-  protected[this] def isString: Boolean
-
   override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
-    val offset = target.getBaseOffset + cursor
     val bytes = getBytes(source, column)
+    write(target, bytes, column, cursor)
+  }
+
+  def write(target: UnsafeRow, bytes: Array[Byte], column: Int, cursor: Int): Int = {
+    val offset = target.getBaseOffset + cursor
     val numBytes = bytes.length
     if ((numBytes & 0x07) > 0) {
       // zero-out the padding bytes
@@ -256,22 +255,32 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
   }
 }
 
-private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
+private object StringUnsafeColumnWriter extends BytesUnsafeColumnWriter {
   protected[this] def isString: Boolean = true
   def getBytes(source: InternalRow, column: Int): Array[Byte] = {
     source.getAs[UTF8String](column).getBytes
   }
+  // TODO(davies): refactor this
+  // specialized for codegen
+  def getSize(value: UTF8String): Int =
+    ByteArrayMethods.roundNumberOfBytesToNearestWord(value.numBytes())
+  def write(target: UnsafeRow, value: UTF8String, column: Int, cursor: Int): Int = {
+    write(target, value.getBytes, column, cursor)
+  }
 }
 
-private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
-  protected[this] def isString: Boolean = false
-  def getBytes(source: InternalRow, column: Int): Array[Byte] = {
+private object BinaryUnsafeColumnWriter extends BytesUnsafeColumnWriter {
+  protected[this] override def isString: Boolean = false
+  override def getBytes(source: InternalRow, column: Int): Array[Byte] = {
     source.getAs[Array[Byte]](column)
   }
+  // specialized for codegen
+  def getSize(value: Array[Byte]): Int =
+    ByteArrayMethods.roundNumberOfBytesToNearestWord(value.length)
 }
 
-private class ObjectUnsafeColumnWriter private() extends UnsafeColumnWriter {
-  def getSize(sourceRow: InternalRow, column: Int): Int = 0
+private object ObjectUnsafeColumnWriter extends UnsafeColumnWriter {
+  override def getSize(sourceRow: InternalRow, column: Int): Int = 0
   override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
     val obj = source.get(column)
     val idx = target.getPool.put(obj)

http://git-wip-us.apache.org/repos/asf/spark/blob/ec8973d1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 328d635..45dc146 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -24,6 +24,7 @@ import com.google.common.cache.{CacheBuilder, CacheLoader}
 import org.codehaus.janino.ClassBodyEvaluator
 
 import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -68,9 +69,6 @@ class CodeGenContext {
     mutableStates += ((javaType, variableName, initialValue))
   }
 
-  val stringType: String = classOf[UTF8String].getName
-  val decimalType: String = classOf[Decimal].getName
-
   final val JAVA_BOOLEAN = "boolean"
   final val JAVA_BYTE = "byte"
   final val JAVA_SHORT = "short"
@@ -136,9 +134,9 @@ class CodeGenContext {
     case LongType | TimestampType => JAVA_LONG
     case FloatType => JAVA_FLOAT
     case DoubleType => JAVA_DOUBLE
-    case dt: DecimalType => decimalType
+    case dt: DecimalType => "Decimal"
     case BinaryType => "byte[]"
-    case StringType => stringType
+    case StringType => "UTF8String"
     case _: StructType => "InternalRow"
     case _: ArrayType => s"scala.collection.Seq"
     case _: MapType => s"scala.collection.Map"
@@ -262,7 +260,12 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
   private[this] def doCompile(code: String): GeneratedClass = {
     val evaluator = new ClassBodyEvaluator()
     evaluator.setParentClassLoader(getClass.getClassLoader)
-    evaluator.setDefaultImports(Array("org.apache.spark.sql.catalyst.InternalRow"))
+    evaluator.setDefaultImports(Array(
+      classOf[InternalRow].getName,
+      classOf[UnsafeRow].getName,
+      classOf[UTF8String].getName,
+      classOf[Decimal].getName
+    ))
     evaluator.setExtendedClass(classOf[GeneratedClass])
     try {
       evaluator.cook(code)

http://git-wip-us.apache.org/repos/asf/spark/blob/ec8973d1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 3e5ca30..8f9fcbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.types._
 /**
  * Java can not access Projection (in package object)
  */
-abstract class BaseProject extends Projection {}
+abstract class BaseProjection extends Projection {}
 
 /**
  * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input
@@ -160,7 +160,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
       return new SpecificProjection(expr);
     }
 
-    class SpecificProjection extends ${classOf[BaseProject].getName} {
+    class SpecificProjection extends ${classOf[BaseProjection].getName} {
       private $exprType[] expressions = null;
       $mutableStates
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ec8973d1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
new file mode 100644
index 0000000..a81d545
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{NullType, BinaryType, StringType}
+
+
+/**
+ * Generates a [[Projection]] that returns an [[UnsafeRow]].
+ *
+ * It generates the code for all the expressions, compute the total length for all the columns
+ * (can be accessed via variables), and then copy the data into a scratch buffer space in the
+ * form of UnsafeRow (the scratch buffer will grow as needed).
+ *
+ * Note: The returned UnsafeRow will be pointed to a scratch buffer inside the projection.
+ */
+object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
+
+  protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
+    in.map(ExpressionCanonicalizer.execute)
+
+  protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
+    in.map(BindReferences.bindReference(_, inputSchema))
+
+  protected def create(expressions: Seq[Expression]): UnsafeProjection = {
+    val ctx = newCodeGenContext()
+    val exprs = expressions.map(_.gen(ctx))
+    val allExprs = exprs.map(_.code).mkString("\n")
+    val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
+    val stringWriter = "org.apache.spark.sql.catalyst.expressions.StringUnsafeColumnWriter"
+    val binaryWriter = "org.apache.spark.sql.catalyst.expressions.BinaryUnsafeColumnWriter"
+    val additionalSize = expressions.zipWithIndex.map { case (e, i) =>
+      e.dataType match {
+        case StringType =>
+          s" + (${exprs(i).isNull} ? 0 : $stringWriter.getSize(${exprs(i).primitive}))"
+        case BinaryType =>
+          s" + (${exprs(i).isNull} ? 0 : $binaryWriter.getSize(${exprs(i).primitive}))"
+        case _ => ""
+      }
+    }.mkString("")
+
+    val writers = expressions.zipWithIndex.map { case (e, i) =>
+      val update = e.dataType match {
+        case dt if ctx.isPrimitiveType(dt) =>
+          s"${ctx.setColumn("target", dt, i, exprs(i).primitive)}"
+        case StringType =>
+          s"cursor += $stringWriter.write(target, ${exprs(i).primitive}, $i, cursor)"
+        case BinaryType =>
+          s"cursor += $binaryWriter.write(target, ${exprs(i).primitive}, $i, cursor)"
+        case NullType => ""
+        case _ =>
+          throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}")
+      }
+      s"""if (${exprs(i).isNull}) {
+            target.setNullAt($i);
+          } else {
+            $update;
+          }"""
+    }.mkString("\n          ")
+
+    val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
+      s"private $javaType $variableName = $initialValue;"
+    }.mkString("\n      ")
+
+    val code = s"""
+    private $exprType[] expressions;
+
+    public Object generate($exprType[] expr) {
+      this.expressions = expr;
+      return new SpecificProjection();
+    }
+
+    class SpecificProjection extends ${classOf[UnsafeProjection].getName} {
+
+      private UnsafeRow target = new UnsafeRow();
+      private byte[] buffer = new byte[64];
+
+      $mutableStates
+
+      public SpecificProjection() {}
+
+      // Scala.Function1 need this
+      public Object apply(Object row) {
+        return apply((InternalRow) row);
+      }
+
+      public UnsafeRow apply(InternalRow i) {
+        ${allExprs}
+
+        // additionalSize had '+' in the beginning
+        int numBytes = $fixedSize $additionalSize;
+        if (numBytes > buffer.length) {
+          buffer = new byte[numBytes];
+        }
+        target.pointTo(buffer, org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
+          ${expressions.size}, numBytes, null);
+        int cursor = $fixedSize;
+        $writers
+        return target;
+      }
+    }
+    """
+
+    logDebug(s"code for ${expressions.mkString(",")}:\n$code")
+
+    val c = compile(code)
+    c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection]
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/ec8973d1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
index 2fa74b4..b9d4736 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -54,7 +54,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     nullSafeCodeGen(ctx, ev, eval => {
       s"""
-        ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull($eval, $precision, $scale);
+        ${ev.primitive} = (new Decimal()).setOrNull($eval, $precision, $scale);
         ${ev.isNull} = ${ev.primitive} == null;
       """
     })

http://git-wip-us.apache.org/repos/asf/spark/blob/ec8973d1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index a7ad452..84b289c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -263,7 +263,7 @@ case class Bin(child: Expression)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     defineCodeGen(ctx, ev, (c) =>
-      s"${ctx.stringType}.fromString(java.lang.Long.toBinaryString($c))")
+      s"UTF8String.fromString(java.lang.Long.toBinaryString($c))")
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ec8973d1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index a269ec4..8d8d66d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -17,12 +17,11 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.security.MessageDigest
-import java.security.NoSuchAlgorithmException
+import java.security.{MessageDigest, NoSuchAlgorithmException}
 import java.util.zip.CRC32
 
 import org.apache.commons.codec.digest.DigestUtils
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -42,7 +41,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     defineCodeGen(ctx, ev, c =>
-      s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
+      s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
   }
 }
 
@@ -93,19 +92,19 @@ case class Sha2(left: Expression, right: Expression)
           try {
             java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224");
             md.update($eval1);
-            ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest());
+            ${ev.primitive} = UTF8String.fromBytes(md.digest());
           } catch (java.security.NoSuchAlgorithmException e) {
             ${ev.isNull} = true;
           }
         } else if ($eval2 == 256 || $eval2 == 0) {
           ${ev.primitive} =
-            ${ctx.stringType}.fromString($digestUtils.sha256Hex($eval1));
+            UTF8String.fromString($digestUtils.sha256Hex($eval1));
         } else if ($eval2 == 384) {
           ${ev.primitive} =
-            ${ctx.stringType}.fromString($digestUtils.sha384Hex($eval1));
+            UTF8String.fromString($digestUtils.sha384Hex($eval1));
         } else if ($eval2 == 512) {
           ${ev.primitive} =
-            ${ctx.stringType}.fromString($digestUtils.sha512Hex($eval1));
+            UTF8String.fromString($digestUtils.sha512Hex($eval1));
         } else {
           ${ev.isNull} = true;
         }
@@ -129,7 +128,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     defineCodeGen(ctx, ev, c =>
-      s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))"
+      s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))"
     )
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ec8973d1/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 43392df..c43486b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -23,7 +23,7 @@ import org.scalatest.Matchers._
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateProjection, GenerateMutableProjection}
 import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
 
@@ -43,6 +43,9 @@ trait ExpressionEvalHelper {
     checkEvaluationWithoutCodegen(expression, catalystValue, inputRow)
     checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow)
     checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow)
+    if (UnsafeColumnWriter.canEmbed(expression.dataType)) {
+      checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow)
+    }
     checkEvaluationWithOptimization(expression, catalystValue, inputRow)
   }
 
@@ -142,6 +145,35 @@ trait ExpressionEvalHelper {
     }
   }
 
+  protected def checkEvalutionWithUnsafeProjection(
+      expression: Expression,
+      expected: Any,
+      inputRow: InternalRow = EmptyRow): Unit = {
+    val ctx = GenerateUnsafeProjection.newCodeGenContext()
+    lazy val evaluated = expression.gen(ctx)
+
+    val plan = try {
+      GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)
+    } catch {
+      case e: Throwable =>
+        fail(
+          s"""
+            |Code generation of $expression failed:
+            |${evaluated.code}
+            |$e
+          """.stripMargin)
+    }
+
+    val unsafeRow = plan(inputRow)
+    // UnsafeRow cannot be compared with GenericInternalRow directly
+    val actual = FromUnsafeProjection(expression.dataType :: Nil)(unsafeRow)
+    val expectedRow = InternalRow(expected)
+    if (actual != expectedRow) {
+      val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+      fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+    }
+  }
+
   protected def checkEvaluationWithOptimization(
       expression: Expression,
       expected: Any,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org