Posted to commits@spark.apache.org by hv...@apache.org on 2016/12/05 19:37:19 UTC

spark git commit: [SPARK-18711][SQL] should disable subexpression elimination for LambdaVariable

Repository: spark
Updated Branches:
  refs/heads/master 246012859 -> 01a7d33d0


[SPARK-18711][SQL] should disable subexpression elimination for LambdaVariable

## What changes were proposed in this pull request?

This is a long-standing bug that remained hidden until https://github.com/apache/spark/pull/15780, which may add `AssertNotNull` on top of `LambdaVariable` and thus enable subexpression elimination.

However, subexpression elimination evaluates the common expressions at the beginning, which is invalid for `LambdaVariable`. A `LambdaVariable` usually represents a loop variable, which can't be evaluated ahead of the loop.
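
To make the failure mode concrete, here is a minimal hand-written sketch (plain Scala, not Spark's actual generated code) of the serializer-style loop that codegen conceptually produces. The loop variable only has a value inside an iteration, so a common subexpression that references it cannot be hoisted in front of the loop:

```scala
// Minimal sketch of a codegen-style serializer loop; `loopVar` plays the
// role of LambdaVariable. Not Spark's actual generated code.
object LoopVarSketch extends App {
  val input = Array(1, 2, 3)
  val out = new Array[Int](input.length)

  // Subexpression elimination would hoist a common expression such as
  // `loopVar + 1` to before the loop -- but `loopVar` is only assigned
  // inside each iteration, so there is nothing valid to evaluate up front.
  var i = 0
  while (i < input.length) {
    val loopVar = input(i)    // per-iteration value: the "loop variable"
    val common = loopVar + 1  // must be evaluated inside the loop
    out(i) = common * common  // reusing `common` within one iteration is fine
    i += 1
  }
  println(out.mkString(", ")) // 4, 9, 16
}
```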

This PR skips expressions containing `LambdaVariable` when doing subexpression elimination.

## How was this patch tested?

Updated a test in `DatasetAggregatorSuite`.

Author: Wenchen Fan <we...@databricks.com>

Closes #16143 from cloud-fan/aggregator.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/01a7d33d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/01a7d33d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/01a7d33d

Branch: refs/heads/master
Commit: 01a7d33d0851d82fd1bb477a58d9925fe8d727d8
Parents: 2460128
Author: Wenchen Fan <we...@databricks.com>
Authored: Mon Dec 5 11:37:13 2016 -0800
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Mon Dec 5 11:37:13 2016 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/EquivalentExpressions.scala     | 6 +++++-
 .../scala/org/apache/spark/sql/DatasetAggregatorSuite.scala  | 8 ++++----
 2 files changed, 9 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/01a7d33d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index b8e2b67..6c246a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
 
 /**
  * This class is used to compute equality of (sub)expression trees. Expressions can be added
@@ -72,7 +73,10 @@ class EquivalentExpressions {
       root: Expression,
       ignoreLeaf: Boolean = true,
       skipReferenceToExpressions: Boolean = true): Unit = {
-    val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
+    val skip = (root.isInstanceOf[LeafExpression] && ignoreLeaf) ||
+      // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
+      // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
+      root.find(_.isInstanceOf[LambdaVariable]).isDefined
     // There are some special expressions whose children we should not recurse into:
     //   1. CodegenFallback: its children will not be used to generate code (call eval() instead)
     //   2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination.
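
For context, Catalyst's `TreeNode.find` does a pre-order traversal and returns the first node matching the predicate, so the new condition skips any expression that contains a `LambdaVariable` anywhere in its subtree. A standalone sketch of the same search idea (hypothetical `Node`, `LoopVar` and `Add` types, not Spark's `Expression`):

```scala
// Hypothetical mini expression tree illustrating the `find` check above.
sealed trait Node {
  def children: Seq[Node]
  // Pre-order search: this node first, then each child's subtree.
  def find(p: Node => Boolean): Option[Node] =
    if (p(this)) Some(this)
    else children.foldLeft(Option.empty[Node])((acc, c) => acc.orElse(c.find(p)))
}
case class LoopVar(name: String) extends Node { def children: Seq[Node] = Nil }
case class Add(children: Node*) extends Node

object FindSketch extends App {
  val tree = Add(LoopVar("i"), Add(LoopVar("j")))
  // Analogous to root.find(_.isInstanceOf[LambdaVariable]).isDefined:
  println(tree.find(_.isInstanceOf[LoopVar]).isDefined) // true => skip subtree
}
```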

http://git-wip-us.apache.org/repos/asf/spark/blob/01a7d33d/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 36b2651..0e7eaa9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -92,13 +92,13 @@ object NameAgg extends Aggregator[AggData, String, String] {
 }
 
 
-object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[Int]] {
+object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[(Int, Int)]] {
   def zero: Seq[Int] = Nil
   def reduce(b: Seq[Int], a: AggData): Seq[Int] = a.a +: b
   def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2
-  def finish(r: Seq[Int]): Seq[Int] = r
+  def finish(r: Seq[Int]): Seq[(Int, Int)] = r.map(i => i -> i)
   override def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
-  override def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+  override def outputEncoder: Encoder[Seq[(Int, Int)]] = ExpressionEncoder()
 }
 
 
@@ -281,7 +281,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
 
     checkDataset(
       ds.groupByKey(_.b).agg(SeqAgg.toColumn),
-      "a" -> Seq(1, 2)
+      "a" -> Seq(1 -> 1, 2 -> 2)
     )
   }
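
As an end-to-end sketch of what the updated test exercises (assumes a `SparkSession` named `spark` and the suite's `AggData(a: Int, b: String)` case class; the input values are illustrative, not the suite's actual fixture):

```scala
import spark.implicits._

// Illustrative input: two rows that share the grouping key "a".
val ds = Seq(AggData(1, "a"), AggData(2, "a")).toDS()

// groupByKey + a typed Aggregator column; with the new finish(), each
// buffered Int i surfaces in the output as the pair (i, i).
val result = ds.groupByKey(_.b).agg(SeqAgg.toColumn).collect()
// e.g. Array(("a", Seq((1, 1), (2, 2)))) -- element order follows reduce/merge order
```

The interesting part is `finish`: producing `Seq[(Int, Int)]` forces the output serializer to build a tuple per element inside a loop, which exercises the `LambdaVariable` code path described above.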
 

