You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yh...@apache.org on 2016/04/08 02:23:38 UTC

spark git commit: [SPARK-14270][SQL] whole stage codegen support for typed filter

Repository: spark
Updated Branches:
  refs/heads/master ae1db91d1 -> 49fb23708


[SPARK-14270][SQL] whole stage codegen support for typed filter

## What changes were proposed in this pull request?

We implement typed filter with `MapPartitions`, which doesn't work well with whole stage codegen. This PR uses `Filter` to implement typed filter, so we get whole stage codegen support for free.

This PR also introduces `DeserializeToObject` and `SerializeFromObject` to separate serialization logic from object operators, so that it's easier to write optimization rules for adjacent object operators.

## How was this patch tested?

existing tests.

Author: Wenchen Fan <we...@databricks.com>

Closes #12061 from cloud-fan/whole-stage-codegen.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/49fb2370
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/49fb2370
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/49fb2370

Branch: refs/heads/master
Commit: 49fb237081bbca0d811aa48aa06f4728fea62781
Parents: ae1db91
Author: Wenchen Fan <we...@databricks.com>
Authored: Thu Apr 7 17:23:34 2016 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Thu Apr 7 17:23:34 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  2 +
 .../apache/spark/sql/catalyst/dsl/package.scala | 19 +++++
 .../sql/catalyst/optimizer/Optimizer.scala      | 39 +++++++++-
 .../sql/catalyst/plans/logical/object.scala     | 37 +++++++++-
 .../TypedFilterOptimizationSuite.scala          | 74 +++++++++++++++++++
 .../scala/org/apache/spark/sql/Dataset.scala    | 16 ++++-
 .../spark/sql/execution/SparkStrategies.scala   |  4 ++
 .../apache/spark/sql/execution/objects.scala    | 67 +++++++++++++++++
 .../org/apache/spark/sql/DatasetBenchmark.scala | 76 +++++++++++++++++---
 .../scala/org/apache/spark/sql/QueryTest.scala  |  2 +
 .../sql/execution/WholeStageCodegenSuite.scala  | 21 +++++-
 11 files changed, 342 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/49fb2370/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7bcba42..3555a6d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1670,6 +1670,8 @@ object CleanupAliases extends Rule[LogicalPlan] {
     // Operators that operate on objects should only have expressions from encoders, which should
     // never have extra aliases.
     case o: ObjectOperator => o
+    case d: DeserializeToObject => d
+    case s: SerializeFromObject => s
 
     case other =>
       var stop = false

http://git-wip-us.apache.org/repos/asf/spark/blob/49fb2370/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 1059470..1e72966 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
 
 import scala.language.implicitConversions
 
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -166,6 +167,14 @@ package object dsl {
       case target => UnresolvedStar(Option(target))
     }
 
+    def callFunction[T, U](
+        func: T => U,
+        returnType: DataType,
+        argument: Expression): Expression = {
+      val function = Literal.create(func, ObjectType(classOf[T => U]))
+      Invoke(function, "apply", returnType, argument :: Nil)
+    }
+
     implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
     // TODO more implicit class for literal?
     implicit class DslString(val s: String) extends ImplicitOperators {
@@ -270,6 +279,16 @@ package object dsl {
 
       def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
 
+      def filter[T : Encoder](func: T => Boolean): LogicalPlan = {
+        val deserialized = logicalPlan.deserialize[T]
+        val condition = expressions.callFunction(func, BooleanType, deserialized.output.head)
+        Filter(condition, deserialized).serialize[T]
+      }
+
+      def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan)
+
+      def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan)
+
       def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan)
 
       def join(

http://git-wip-us.apache.org/repos/asf/spark/blob/49fb2370/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f581810..619514e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -93,6 +93,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
       EliminateSerialization) ::
     Batch("Decimal Optimizations", FixedPoint(100),
       DecimalAggregates) ::
+    Batch("Typed Filter Optimization", FixedPoint(100),
+      EmbedSerializerInFilter) ::
     Batch("LocalRelation", FixedPoint(100),
       ConvertToLocalRelation) ::
     Batch("Subquery", Once,
@@ -147,12 +149,18 @@ object EliminateSerialization extends Rule[LogicalPlan] {
         child = childWithoutSerialization)
 
     case m @ MapElements(_, deserializer, _, child: ObjectOperator)
-      if !deserializer.isInstanceOf[Attribute] &&
-        deserializer.dataType == child.outputObject.dataType =>
+        if !deserializer.isInstanceOf[Attribute] &&
+          deserializer.dataType == child.outputObject.dataType =>
       val childWithoutSerialization = child.withObjectOutput
       m.copy(
         deserializer = childWithoutSerialization.output.head,
         child = childWithoutSerialization)
+
+    case d @ DeserializeToObject(_, s: SerializeFromObject)
+        if d.outputObjectType == s.inputObjectType =>
+      // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
+      val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
+      Project(objAttr :: Nil, s.child)
   }
 }
 
@@ -1329,3 +1337,30 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
+ * [[SerializeFromObject]] above it.  If these serializations can't be eliminated, we should embed
+ * the deserializer in the filter condition to save the extra serialization at the end.
+ */
+object EmbedSerializerInFilter extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) =>
+      val numObjects = condition.collect {
+        case a: Attribute if a == d.output.head => a
+      }.length
+
+      if (numObjects > 1) {
+        // If the filter condition references the object more than once, we should not embed the
+        // deserializer in it, as the deserialization would happen multiple times and slow down
+        // the execution.
+        // TODO: we can still embed it if we can make sure subexpression elimination works here.
+        s
+      } else {
+        val newCondition = condition transform {
+          case a: Attribute if a == d.output.head => d.deserializer.child
+        }
+        Filter(newCondition, d.child)
+      }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/49fb2370/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index ec33a53..6df4618 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -21,7 +21,42 @@ import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{ObjectType, StructType}
+import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
+
+object CatalystSerde {
+  def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
+    val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
+    DeserializeToObject(Alias(deserializer, "obj")(), child)
+  }
+
+  def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
+    SerializeFromObject(encoderFor[T].namedExpressions, child)
+  }
+}
+
+/**
+ * Takes the input row from child and turns it into object using the given deserializer expression.
+ * The output of this operator is a single-field safe row containing the deserialized object.
+ */
+case class DeserializeToObject(
+    deserializer: Alias,
+    child: LogicalPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = deserializer.toAttribute :: Nil
+
+  def outputObjectType: DataType = deserializer.dataType
+}
+
+/**
+ * Takes the input object from child and turns it into an unsafe row using the given serializer
+ * expression.  The output of its child must be a single-field row containing the input object.
+ */
+case class SerializeFromObject(
+    serializer: Seq[NamedExpression],
+    child: LogicalPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  def inputObjectType: DataType = child.output.head.dataType
+}
 
 /**
  * A trait for logical operators that apply user defined functions to domain objects.

http://git-wip-us.apache.org/repos/asf/spark/blob/49fb2370/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
new file mode 100644
index 0000000..1fae64e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.BooleanType
+
+class TypedFilterOptimizationSuite extends PlanTest {
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("EliminateSerialization", FixedPoint(50),
+        EliminateSerialization) ::
+      Batch("EmbedSerializerInFilter", FixedPoint(50),
+        EmbedSerializerInFilter) :: Nil
+  }
+
+  implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
+
+  test("back to back filter") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f1 = (i: (Int, Int)) => i._1 > 0
+    val f2 = (i: (Int, Int)) => i._2 > 0
+
+    val query = input.filter(f1).filter(f2).analyze
+
+    val optimized = Optimize.execute(query)
+
+    val expected = input.deserialize[(Int, Int)]
+      .where(callFunction(f1, BooleanType, 'obj))
+      .select('obj.as("obj"))
+      .where(callFunction(f2, BooleanType, 'obj))
+      .serialize[(Int, Int)].analyze
+
+    comparePlans(optimized, expected)
+  }
+
+  test("embed deserializer in filter condition if there is only one filter") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f = (i: (Int, Int)) => i._1 > 0
+
+    val query = input.filter(f).analyze
+
+    val optimized = Optimize.execute(query)
+
+    val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer)
+    val condition = callFunction(f, BooleanType, deserializer)
+    val expected = input.where(condition).analyze
+
+    comparePlans(optimized, expected)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/49fb2370/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 2854d5f..2f6d8d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1879,7 +1879,13 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
+  def filter(func: T => Boolean): Dataset[T] = {
+    val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+    val function = Literal.create(func, ObjectType(classOf[T => Boolean]))
+    val condition = Invoke(function, "apply", BooleanType, deserialized.output)
+    val filter = Filter(condition, deserialized)
+    withTypedPlan(CatalystSerde.serialize[T](filter))
+  }
 
   /**
    * :: Experimental ::
@@ -1890,7 +1896,13 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
+  def filter(func: FilterFunction[T]): Dataset[T] = {
+    val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+    val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]]))
+    val condition = Invoke(function, "call", BooleanType, deserialized.output)
+    val filter = Filter(condition, deserialized)
+    withTypedPlan(CatalystSerde.serialize[T](filter))
+  }
 
   /**
    * :: Experimental ::

http://git-wip-us.apache.org/repos/asf/spark/blob/49fb2370/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index eee2b94..c15aaed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -346,6 +346,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         throw new IllegalStateException(
           "logical intersect operator should have been replaced by semi-join in the optimizer")
 
+      case logical.DeserializeToObject(deserializer, child) =>
+        execution.DeserializeToObject(deserializer, planLater(child)) :: Nil
+      case logical.SerializeFromObject(serializer, child) =>
+        execution.SerializeFromObject(serializer, planLater(child)) :: Nil
       case logical.MapPartitions(f, in, out, child) =>
         execution.MapPartitions(f, in, out, planLater(child)) :: Nil
       case logical.MapElements(f, in, out, child) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/49fb2370/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index f48f3f0..d2ab18e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -28,6 +28,73 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.ObjectType
 
 /**
+ * Takes the input row from child and turns it into object using the given deserializer expression.
+ * The output of this operator is a single-field safe row containing the deserialized object.
+ */
+case class DeserializeToObject(
+    deserializer: Alias,
+    child: SparkPlan) extends UnaryNode with CodegenSupport {
+  override def output: Seq[Attribute] = deserializer.toAttribute :: Nil
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val bound = ExpressionCanonicalizer.execute(
+      BindReferences.bindReference(deserializer, child.output))
+    ctx.currentVars = input
+    val resultVars = bound.gen(ctx) :: Nil
+    consume(ctx, resultVars)
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsInternal { iter =>
+      val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output)
+      iter.map(projection)
+    }
+  }
+}
+
+/**
+ * Takes the input object from child and turns it into an unsafe row using the given serializer
+ * expression.  The output of its child must be a single-field row containing the input object.
+ */
+case class SerializeFromObject(
+    serializer: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryNode with CodegenSupport {
+  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val bound = serializer.map { expr =>
+      ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output))
+    }
+    ctx.currentVars = input
+    val resultVars = bound.map(_.gen(ctx))
+    consume(ctx, resultVars)
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsInternal { iter =>
+      val projection = UnsafeProjection.create(serializer)
+      iter.map(projection)
+    }
+  }
+}
+
+/**
  * Helper functions for physical operators that work with user defined objects.
  */
 trait ObjectOperator extends SparkPlan {

http://git-wip-us.apache.org/repos/asf/spark/blob/49fb2370/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
index 6eb9524..5f3dd90 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -28,16 +28,10 @@ object DatasetBenchmark {
 
   case class Data(l: Long, s: String)
 
-  def main(args: Array[String]): Unit = {
-    val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
-    val sqlContext = new SQLContext(sparkContext)
-
+  def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = {
     import sqlContext.implicits._
 
-    val numRows = 10000000
     val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
-    val numChains = 10
-
     val benchmark = new Benchmark("back-to-back map", numRows)
 
     val func = (d: Data) => Data(d.l + 1, d.s)
@@ -61,7 +55,7 @@ object DatasetBenchmark {
       res.queryExecution.toRdd.foreach(_ => Unit)
     }
 
-    val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+    val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
     benchmark.addCase("RDD") { iter =>
       var res = rdd
       var i = 0
@@ -72,6 +66,63 @@ object DatasetBenchmark {
       res.foreach(_ => Unit)
     }
 
+    benchmark
+  }
+
+  def backToBackFilter(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = {
+    import sqlContext.implicits._
+    // Registers cases that chain `numChains` filters each; the caller runs the result.
+    val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+    val benchmark = new Benchmark("back-to-back filter", numRows)
+    // Step i of the chain keeps only the rows where l % (100 + i) == 0.
+    val func = (d: Data, i: Int) => d.l % (100L + i) == 0L
+    val funcs = 0.until(numChains).map { i =>
+      (d: Data) => func(d, i)
+    }
+    benchmark.addCase("Dataset") { iter =>
+      var res = df.as[Data]
+      var i = 0
+      while (i < numChains) {
+        res = res.filter(funcs(i))
+        i += 1
+      }
+      res.queryExecution.toRdd.foreach(_ => Unit)
+    }
+
+    benchmark.addCase("DataFrame") { iter =>
+      var res = df
+      var i = 0
+      while (i < numChains) {
+        res = res.filter($"l" % (100L + i) === 0L)
+        i += 1
+      }
+      res.queryExecution.toRdd.foreach(_ => Unit)
+    }
+
+    val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+    benchmark.addCase("RDD") { iter =>
+      var res = rdd
+      var i = 0
+      while (i < numChains) {
+        res = res.filter(funcs(i))  // chain onto `res`, not `rdd`, so every filter is applied
+        i += 1
+      }
+      res.foreach(_ => Unit)
+    }
+
+    benchmark
+  }
+
+  def main(args: Array[String]): Unit = {
+    val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
+    val sqlContext = new SQLContext(sparkContext)
+
+    val numRows = 10000000
+    val numChains = 10
+
+    val benchmark = backToBackMap(sqlContext, numRows, numChains)
+    val benchmark2 = backToBackFilter(sqlContext, numRows, numChains)
+
     /*
     Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
     Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
@@ -82,5 +133,14 @@ object DatasetBenchmark {
     RDD                                       216 /  237         46.3          21.6       4.2X
     */
     benchmark.run()
+
+    /*
+    back-to-back filter:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    Dataset                                   585 /  628         17.1          58.5       1.0X
+    DataFrame                                  62 /   80        160.7           6.2       9.4X
+    RDD                                       205 /  220         48.7          20.5       2.8X
+    */
+    benchmark2.run()
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/49fb2370/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 48a077d..8268628 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.streaming.MemoryPlan
+import org.apache.spark.sql.types.ObjectType
 
 abstract class QueryTest extends PlanTest {
 
@@ -204,6 +205,7 @@ abstract class QueryTest extends PlanTest {
       case _: MemoryPlan => return
     }.transformAllExpressions {
       case a: ImperativeAggregate => return
+      case Literal(_, _: ObjectType) => return
     }
 
     // bypass hive tests before we fix all corner cases in hive module.

http://git-wip-us.apache.org/repos/asf/spark/blob/49fb2370/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index f73ca88..4474cfc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,8 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.api.java.function.MapFunction
-import org.apache.spark.sql.{Encoders, Row}
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.execution.joins.BroadcastHashJoin
 import org.apache.spark.sql.functions.{avg, broadcast, col, max}
@@ -82,4 +81,22 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
         p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined)
     assert(ds.collect() === 0.until(10).map(_.toString).toArray)
   }
+
+  test("typed filter should be included in WholeStageCodegen") {
+    val ds = sqlContext.range(10).filter(_ % 2 == 0)
+    val plan = ds.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Filter]).isDefined)
+    assert(ds.collect() === Array(0, 2, 4, 6, 8))
+  }
+
+  test("back-to-back typed filter should be included in WholeStageCodegen") {
+    val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
+    val plan = ds.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined)
+    assert(ds.collect() === Array(0, 6))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org