You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2016/04/06 06:09:21 UTC

spark git commit: [SPARK-14296][SQL] whole stage codegen support for Dataset.map

Repository: spark
Updated Branches:
  refs/heads/master 8e5c1cbf2 -> f6456fa80


[SPARK-14296][SQL] whole stage codegen support for Dataset.map

## What changes were proposed in this pull request?

This PR adds a new operator `MapElements` for `Dataset.map`, it's a 1-1 mapping and is easier to adapt to whole stage codegen framework.

## How was this patch tested?

new test in `WholeStageCodegenSuite`

Author: Wenchen Fan <we...@databricks.com>

Closes #12087 from cloud-fan/map.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f6456fa8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f6456fa8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f6456fa8

Branch: refs/heads/master
Commit: f6456fa80ba442bfd7ce069fc23b7dbd993e6cb9
Parents: 8e5c1cb
Author: Wenchen Fan <we...@databricks.com>
Authored: Wed Apr 6 12:09:10 2016 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Wed Apr 6 12:09:10 2016 +0800

----------------------------------------------------------------------
 .../sql/catalyst/analysis/unresolved.scala      |  2 +-
 .../sql/catalyst/expressions/objects.scala      | 40 +++++----
 .../sql/catalyst/optimizer/Optimizer.scala      |  9 ++
 .../sql/catalyst/plans/logical/object.scala     | 28 ++++++-
 .../scala/org/apache/spark/sql/Dataset.scala    | 22 ++---
 .../spark/sql/execution/SparkStrategies.scala   |  2 +
 .../spark/sql/execution/WholeStageCodegen.scala | 11 ++-
 .../apache/spark/sql/execution/objects.scala    | 69 +++++++++++++++-
 .../org/apache/spark/sql/DatasetBenchmark.scala | 86 ++++++++++++++++++++
 .../scala/org/apache/spark/sql/QueryTest.scala  |  5 +-
 .../sql/execution/WholeStageCodegenSuite.scala  | 14 +++-
 11 files changed, 247 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f6456fa8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index b2f362b..4ec43ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -345,7 +345,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
  * @param inputAttributes The input attributes used to resolve deserializer expression, can be empty
  *                        if we want to resolve deserializer by children output.
  */
-case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute])
+case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute] = Nil)
   extends UnaryExpression with Unevaluable with NonSQLExpression {
   // The input attributes used to resolve deserializer expression must be all resolved.
   require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.")

http://git-wip-us.apache.org/repos/asf/spark/blob/f6456fa8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index eebd43d..a0490e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -119,18 +119,18 @@ case class Invoke(
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
 
-  lazy val method = targetObject.dataType match {
+  @transient lazy val method = targetObject.dataType match {
     case ObjectType(cls) =>
-      cls
-        .getMethods
-        .find(_.getName == functionName)
-        .getOrElse(sys.error(s"Couldn't find $functionName on $cls"))
-        .getReturnType
-        .getName
-    case _ => ""
+      val m = cls.getMethods.find(_.getName == functionName)
+      if (m.isEmpty) {
+        sys.error(s"Couldn't find $functionName on $cls")
+      } else {
+        m
+      }
+    case _ => None
   }
 
-  lazy val unboxer = (dataType, method) match {
+  lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match {
     case (IntegerType, "java.lang.Object") => (s: String) =>
       s"((java.lang.Integer)$s).intValue()"
     case (LongType, "java.lang.Object") => (s: String) =>
@@ -157,21 +157,31 @@ case class Invoke(
     // If the function can return null, we do an extra check to make sure our null bit is still set
     // correctly.
     val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
-      s"${ev.isNull} = ${ev.value} == null;"
+      s"boolean ${ev.isNull} = ${ev.value} == null;"
     } else {
+      ev.isNull = obj.isNull
       ""
     }
 
     val value = unboxer(s"${obj.value}.$functionName($argString)")
 
+    val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
+      s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;"
+    } else {
+      s"""
+        $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+        try {
+          ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value;
+        } catch (Exception e) {
+          org.apache.spark.unsafe.Platform.throwException(e);
+        }
+      """
+    }
+
     s"""
       ${obj.code}
       ${argGen.map(_.code).mkString("\n")}
-
-      boolean ${ev.isNull} = ${obj.isNull};
-      $javaType ${ev.value} =
-        ${ev.isNull} ?
-        ${ctx.defaultValue(dataType)} : ($javaType) $value;
+      $evaluate
       $objNullCheck
     """
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/f6456fa8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 69b09bc..c085a37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -136,6 +136,7 @@ object SamplePushDown extends Rule[LogicalPlan] {
  * representation of data item.  For example back to back map operations.
  */
 object EliminateSerialization extends Rule[LogicalPlan] {
+  // TODO: find a more general way to do this optimization.
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case m @ MapPartitions(_, deserializer, _, child: ObjectOperator)
         if !deserializer.isInstanceOf[Attribute] &&
@@ -144,6 +145,14 @@ object EliminateSerialization extends Rule[LogicalPlan] {
       m.copy(
         deserializer = childWithoutSerialization.output.head,
         child = childWithoutSerialization)
+
+    case m @ MapElements(_, deserializer, _, child: ObjectOperator)
+      if !deserializer.isInstanceOf[Attribute] &&
+        deserializer.dataType == child.outputObject.dataType =>
+      val childWithoutSerialization = child.withObjectOutput
+      m.copy(
+        deserializer = childWithoutSerialization.output.head,
+        child = childWithoutSerialization)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f6456fa8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 58313c7..ec33a53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -65,7 +65,7 @@ object MapPartitions {
       child: LogicalPlan): MapPartitions = {
     MapPartitions(
       func.asInstanceOf[Iterator[Any] => Iterator[Any]],
-      UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
+      UnresolvedDeserializer(encoderFor[T].deserializer),
       encoderFor[U].namedExpressions,
       child)
   }
@@ -83,6 +83,30 @@ case class MapPartitions(
     serializer: Seq[NamedExpression],
     child: LogicalPlan) extends UnaryNode with ObjectOperator
 
+object MapElements {
+  def apply[T : Encoder, U : Encoder](
+      func: AnyRef,
+      child: LogicalPlan): MapElements = {
+    MapElements(
+      func,
+      UnresolvedDeserializer(encoderFor[T].deserializer),
+      encoderFor[U].namedExpressions,
+      child)
+  }
+}
+
+/**
+ * A relation produced by applying `func` to each element of the `child`.
+ *
+ * @param deserializer used to extract the input to `func` from an input row.
+ * @param serializer use to serialize the output of `func`.
+ */
+case class MapElements(
+    func: AnyRef,
+    deserializer: Expression,
+    serializer: Seq[NamedExpression],
+    child: LogicalPlan) extends UnaryNode with ObjectOperator
+
 /** Factory for constructing new `AppendColumn` nodes. */
 object AppendColumns {
   def apply[T : Encoder, U : Encoder](
@@ -90,7 +114,7 @@ object AppendColumns {
       child: LogicalPlan): AppendColumns = {
     new AppendColumns(
       func.asInstanceOf[Any => Any],
-      UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
+      UnresolvedDeserializer(encoderFor[T].deserializer),
       encoderFor[U].namedExpressions,
       child)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/f6456fa8/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f472a50..2854d5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -766,7 +766,8 @@ class Dataset[T] private[sql](
 
     implicit val tuple2Encoder: Encoder[(T, U)] =
       ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
-    withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) =>
+
+    withTypedPlan {
       Project(
         leftData :: rightData :: Nil,
         joined.analyzed)
@@ -1900,7 +1901,9 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
+  def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan {
+    MapElements[T, U](func, logicalPlan)
+  }
 
   /**
    * :: Experimental ::
@@ -1911,8 +1914,10 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
-    map(t => func.call(t))(encoder)
+  def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+    implicit val uEnc = encoder
+    withTypedPlan(MapElements[T, U](func, logicalPlan))
+  }
 
   /**
    * :: Experimental ::
@@ -2412,12 +2417,7 @@ class Dataset[T] private[sql](
   }
 
   /** A convenient function to wrap a logical plan and produce a Dataset. */
-  @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = {
-    new Dataset[T](sqlContext, logicalPlan, encoder)
+  @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
+    Dataset(sqlContext, logicalPlan)
   }
-
-  private[sql] def withTypedPlan[R](
-      other: Dataset[_], encoder: Encoder[R])(
-      f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
-    new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f6456fa8/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e52f05a..5f3128d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -341,6 +341,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
       case logical.MapPartitions(f, in, out, child) =>
         execution.MapPartitions(f, in, out, planLater(child)) :: Nil
+      case logical.MapElements(f, in, out, child) =>
+        execution.MapElements(f, in, out, planLater(child)) :: Nil
       case logical.AppendColumns(f, in, out, child) =>
         execution.AppendColumns(f, in, out, planLater(child)) :: Nil
       case logical.MapGroups(f, key, in, out, grouping, data, child) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/f6456fa8/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 9f539c4..4e75a3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -152,7 +152,7 @@ trait CodegenSupport extends SparkPlan {
     s"""
        |
        |/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */
-       |${evaluated}
+       |$evaluated
        |${parent.doConsume(ctx, inputVars, rowVar)}
      """.stripMargin
   }
@@ -169,20 +169,20 @@ trait CodegenSupport extends SparkPlan {
 
   /**
    * Returns source code to evaluate the variables for required attributes, and clear the code
-   * of evaluated variables, to prevent them to be evaluated twice..
+   * of evaluated variables, to prevent them to be evaluated twice.
    */
   protected def evaluateRequiredVariables(
       attributes: Seq[Attribute],
       variables: Seq[ExprCode],
       required: AttributeSet): String = {
-    var evaluateVars = ""
+    val evaluateVars = new StringBuilder
     variables.zipWithIndex.foreach { case (ev, i) =>
       if (ev.code != "" && required.contains(attributes(i))) {
-        evaluateVars += ev.code.trim + "\n"
+        evaluateVars.append(ev.code.trim + "\n")
         ev.code = ""
       }
     }
-    evaluateVars
+    evaluateVars.toString()
   }
 
   /**
@@ -305,7 +305,6 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
   def doCodeGen(): (CodegenContext, String) = {
     val ctx = new CodegenContext
     val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
-    val references = ctx.references.toArray
     val source = s"""
       public Object generate(Object[] references) {
         return new GeneratedIterator(references);

http://git-wip-us.apache.org/repos/asf/spark/blob/f6456fa8/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 582dda8..f48f3f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.sql.execution
 
+import scala.language.existentials
+
+import org.apache.spark.api.java.function.MapFunction
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.ObjectType
 
@@ -68,6 +71,70 @@ case class MapPartitions(
 }
 
 /**
+ * Applies the given function to each input row and encodes the result.
+ *
+ * Note that, each serializer expression needs the result object which is returned by the given
+ * function, as input. This operator uses some tricks to make sure we only calculate the result
+ * object once. We don't use [[Project]] directly as subexpression elimination doesn't work with
+ * whole stage codegen and it's confusing to show the un-common-subexpression-eliminated version of
+ * a project while explain.
+ */
+case class MapElements(
+    func: AnyRef,
+    deserializer: Expression,
+    serializer: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport {
+  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val (funcClass, methodName) = func match {
+      case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
+      case _ => classOf[Any => Any] -> "apply"
+    }
+    val funcObj = Literal.create(func, ObjectType(funcClass))
+    val resultObjType = serializer.head.collect { case b: BoundReference => b }.head.dataType
+    val callFunc = Invoke(funcObj, methodName, resultObjType, Seq(deserializer))
+
+    val bound = ExpressionCanonicalizer.execute(
+      BindReferences.bindReference(callFunc, child.output))
+    ctx.currentVars = input
+    val evaluated = bound.gen(ctx)
+
+    val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType)
+    val outputFields = serializer.map(_ transform {
+      case _: BoundReference => resultObj
+    })
+    val resultVars = outputFields.map(_.gen(ctx))
+    s"""
+      ${evaluated.code}
+      ${consume(ctx, resultVars)}
+    """
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val callFunc: Any => Any = func match {
+      case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i)
+      case _ => func.asInstanceOf[Any => Any]
+    }
+    child.execute().mapPartitionsInternal { iter =>
+      val getObject = generateToObject(deserializer, child.output)
+      val outputObject = generateToRow(serializer)
+      iter.map(row => outputObject(callFunc(getObject(row))))
+    }
+  }
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+}
+
+/**
  * Applies the given function to each input row, appending the encoded result at the end of the row.
  */
 case class AppendColumns(

http://git-wip-us.apache.org/repos/asf/spark/blob/f6456fa8/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
new file mode 100644
index 0000000..6eb9524
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.types.StringType
+import org.apache.spark.util.Benchmark
+
+/**
+ * Benchmark for Dataset typed operations comparing with DataFrame and RDD versions.
+ */
+object DatasetBenchmark {
+
+  case class Data(l: Long, s: String)
+
+  def main(args: Array[String]): Unit = {
+    val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
+    val sqlContext = new SQLContext(sparkContext)
+
+    import sqlContext.implicits._
+
+    val numRows = 10000000
+    val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+    val numChains = 10
+
+    val benchmark = new Benchmark("back-to-back map", numRows)
+
+    val func = (d: Data) => Data(d.l + 1, d.s)
+    benchmark.addCase("Dataset") { iter =>
+      var res = df.as[Data]
+      var i = 0
+      while (i < numChains) {
+        res = res.map(func)
+        i += 1
+      }
+      res.queryExecution.toRdd.foreach(_ => Unit)
+    }
+
+    benchmark.addCase("DataFrame") { iter =>
+      var res = df
+      var i = 0
+      while (i < numChains) {
+        res = res.select($"l" + 1 as "l")
+        i += 1
+      }
+      res.queryExecution.toRdd.foreach(_ => Unit)
+    }
+
+    val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+    benchmark.addCase("RDD") { iter =>
+      var res = rdd
+      var i = 0
+      while (i < numChains) {
+        res = rdd.map(func)
+        i += 1
+      }
+      res.foreach(_ => Unit)
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+    back-to-back map:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    Dataset                                   902 /  995         11.1          90.2       1.0X
+    DataFrame                                 132 /  167         75.5          13.2       6.8X
+    RDD                                       216 /  237         46.3          21.6       4.2X
+    */
+    benchmark.run()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6456fa8/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index f7f3bd7..4e62fac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -198,10 +198,7 @@ abstract class QueryTest extends PlanTest {
     val logicalPlan = df.queryExecution.analyzed
     // bypass some cases that we can't handle currently.
     logicalPlan.transform {
-      case _: MapPartitions => return
-      case _: MapGroups => return
-      case _: AppendColumns => return
-      case _: CoGroup => return
+      case _: ObjectOperator => return
       case _: LogicalRelation => return
     }.transformAllExpressions {
       case a: ImperativeAggregate => return

http://git-wip-us.apache.org/repos/asf/spark/blob/f6456fa8/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 6d5be0b..f73ca88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.Row
+import org.apache.spark.api.java.function.MapFunction
+import org.apache.spark.sql.{Encoders, Row}
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.execution.joins.BroadcastHashJoin
 import org.apache.spark.sql.functions.{avg, broadcast, col, max}
@@ -70,4 +71,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
         p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined)
     assert(df.collect() === Array(Row(1), Row(2), Row(3)))
   }
+
+  test("MapElements should be included in WholeStageCodegen") {
+    import testImplicits._
+
+    val ds = sqlContext.range(10).map(_.toString)
+    val plan = ds.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined)
+    assert(ds.collect() === 0.until(10).map(_.toString).toArray)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org