Posted to commits@spark.apache.org by rx...@apache.org on 2016/01/12 19:59:00 UTC

spark git commit: [SPARK-12762][SQL] Add unit test for SimplifyConditionals optimization rule

Repository: spark
Updated Branches:
  refs/heads/master 7e15044d9 -> 1d8887953


[SPARK-12762][SQL] Add unit test for SimplifyConditionals optimization rule

This pull request does a few small things:

1. Separated If simplification from BooleanSimplification and created a new rule, SimplifyConditionals. In the future we can also simplify other conditional expressions here. (A short usage sketch follows the "Closes" line below.)

2. Added a unit test for SimplifyConditionals.

3. Renamed SimplifyCaseConversionExpressionsSuite to SimplifyStringCaseConversionSuite.

Author: Reynold Xin <rx...@databricks.com>

Closes #10716 from rxin/SPARK-12762.
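
For readers skimming the diff, here is a minimal, illustrative sketch of how the new rule is exercised; it is not part of the patch. It mirrors the assertEquivalent helper in the new SimplifyConditionalSuite further down (a Project over OneRowRelation) and applies the rule object directly to the plan.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals

// Wrap a constant-predicate If in a trivial one-row plan, then apply the rule directly.
val plan = Project(Alias(If(TrueLiteral, Literal(10), Literal(20)), "out")() :: Nil, OneRowRelation)
val simplified = SimplifyConditionals(plan)
// simplified now projects Literal(10); the unreachable false branch has been dropped.
// With FalseLiteral as the predicate, the rule keeps Literal(20) instead.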


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1d888795
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1d888795
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1d888795

Branch: refs/heads/master
Commit: 1d8887953018b2e12b6ee47a76e50e542c836b80
Parents: 7e15044
Author: Reynold Xin <rx...@databricks.com>
Authored: Tue Jan 12 10:58:57 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Jan 12 10:58:57 2016 -0800

----------------------------------------------------------------------
 .../expressions/conditionalExpressions.scala    | 10 ++-
 .../sql/catalyst/optimizer/Optimizer.scala      | 10 +++
 .../optimizer/CombiningLimitsSuite.scala        |  3 +-
 ...SimplifyCaseConversionExpressionsSuite.scala | 91 --------------------
 .../optimizer/SimplifyConditionalSuite.scala    | 50 +++++++++++
 .../SimplifyStringCaseConversionSuite.scala     | 90 +++++++++++++++++++
 6 files changed, 158 insertions(+), 96 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1d888795/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 19da849..379e62a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -45,7 +45,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
   override def dataType: DataType = trueValue.dataType
 
   override def eval(input: InternalRow): Any = {
-    if (true == predicate.eval(input)) {
+    if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) {
       trueValue.eval(input)
     } else {
       falseValue.eval(input)
@@ -141,8 +141,8 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
     }
   }
 
-  /** Written in imperative fashion for performance considerations. */
   override def eval(input: InternalRow): Any = {
+    // Written in imperative fashion for performance considerations
     val len = branchesArr.length
     var i = 0
     // If all branches fail and an elseVal is not provided, the whole statement
@@ -389,7 +389,7 @@ case class Least(children: Seq[Expression]) extends Expression {
     val evalChildren = children.map(_.gen(ctx))
     val first = evalChildren(0)
     val rest = evalChildren.drop(1)
-    def updateEval(eval: GeneratedExpressionCode): String =
+    def updateEval(eval: GeneratedExpressionCode): String = {
       s"""
         ${eval.code}
         if (!${eval.isNull} && (${ev.isNull} ||
@@ -398,6 +398,7 @@ case class Least(children: Seq[Expression]) extends Expression {
           ${ev.value} = ${eval.value};
         }
       """
+    }
     s"""
       ${first.code}
       boolean ${ev.isNull} = ${first.isNull};
@@ -447,7 +448,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
     val evalChildren = children.map(_.gen(ctx))
     val first = evalChildren(0)
     val rest = evalChildren.drop(1)
-    def updateEval(eval: GeneratedExpressionCode): String =
+    def updateEval(eval: GeneratedExpressionCode): String = {
       s"""
         ${eval.code}
         if (!${eval.isNull} && (${ev.isNull} ||
@@ -456,6 +457,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
           ${ev.value} = ${eval.value};
         }
       """
+    }
     s"""
       ${first.code}
       boolean ${ev.isNull} = ${first.isNull};
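
An aside on the first hunk above (illustrative, not part of the patch): predicate.eval returns Any and may be null for a SQL NULL predicate, and java.lang.Boolean.TRUE.equals(...) makes the null-safe, boxed comparison explicit instead of relying on Scala's universal equality between a primitive and Any.

// Behaviour of the boxed comparison now used in If.eval (sketch only):
val outcomes: Seq[Any] = Seq(true, false, null)
outcomes.map(r => java.lang.Boolean.TRUE.equals(r))
// => Seq(true, false, false): a null (unknown) predicate falls through to the
//    false branch without any unboxing.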

http://git-wip-us.apache.org/repos/asf/spark/blob/1d888795/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b70bc18..487431f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -63,6 +63,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
       ConstantFolding,
       LikeSimplification,
       BooleanSimplification,
+      SimplifyConditionals,
       RemoveDispensableExpressions,
       SimplifyFilters,
       SimplifyCasts,
@@ -608,7 +609,16 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
       case Not(a And b) => Or(Not(a), Not(b))
 
       case Not(Not(e)) => e
+    }
+  }
+}
 
+/**
+ * Simplifies conditional expressions (if / case).
+ */
+object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case q: LogicalPlan => q transformExpressionsUp {
       case If(TrueLiteral, trueValue, _) => trueValue
       case If(FalseLiteral, _, falseValue) => falseValue
     }
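
Another illustrative note, not part of the patch: because the new rule rewrites with transformExpressionsUp, nested constant-predicate Ifs collapse bottom-up within a single application of the rule. A small sketch, reusing the same plan wrapper as above:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals

val nested = If(FalseLiteral, Literal(1), If(TrueLiteral, Literal(2), Literal(3)))
val plan   = Project(Alias(nested, "out")() :: Nil, OneRowRelation)
// The inner If is visited first and becomes Literal(2); the outer If(FalseLiteral, ...)
// then picks its false branch, so one application yields a projection of Literal(2).
val simplified = SimplifyConditionals(plan)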

http://git-wip-us.apache.org/repos/asf/spark/blob/1d888795/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
index 9fe2b2d..87ad81d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -34,7 +34,8 @@ class CombiningLimitsSuite extends PlanTest {
       Batch("Constant Folding", FixedPoint(10),
         NullPropagation,
         ConstantFolding,
-        BooleanSimplification) :: Nil
+        BooleanSimplification,
+        SimplifyConditionals) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

http://git-wip-us.apache.org/repos/asf/spark/blob/1d888795/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala
deleted file mode 100644
index 4145522..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.optimizer
-
-/* Implicit conversions */
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.rules._
-
-class SimplifyCaseConversionExpressionsSuite extends PlanTest {
-
-  object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches =
-      Batch("Simplify CaseConversionExpressions", Once,
-        SimplifyCaseConversionExpressions) :: Nil
-  }
-
-  val testRelation = LocalRelation('a.string)
-
-  test("simplify UPPER(UPPER(str))") {
-    val originalQuery =
-      testRelation
-        .select(Upper(Upper('a)) as 'u)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      testRelation
-        .select(Upper('a) as 'u)
-        .analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
-  test("simplify UPPER(LOWER(str))") {
-    val originalQuery =
-      testRelation
-        .select(Upper(Lower('a)) as 'u)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      testRelation
-        .select(Upper('a) as 'u)
-        .analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
-  test("simplify LOWER(UPPER(str))") {
-    val originalQuery =
-      testRelation
-        .select(Lower(Upper('a)) as 'l)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer = testRelation
-      .select(Lower('a) as 'l)
-      .analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
-  test("simplify LOWER(LOWER(str))") {
-    val originalQuery =
-      testRelation
-        .select(Lower(Lower('a)) as 'l)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer = testRelation
-      .select(Lower('a) as 'l)
-      .analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/1d888795/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
new file mode 100644
index 0000000..8e5d7ef
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+
+class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
+  }
+
+  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
+    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
+    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
+    comparePlans(actual, correctAnswer)
+  }
+
+  test("simplify if") {
+    assertEquivalent(
+      If(TrueLiteral, Literal(10), Literal(20)),
+      Literal(10))
+
+    assertEquivalent(
+      If(FalseLiteral, Literal(10), Literal(20)),
+      Literal(20))
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1d888795/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala
new file mode 100644
index 0000000..24413e7
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.rules._
+
+class SimplifyStringCaseConversionSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Simplify CaseConversionExpressions", Once,
+        SimplifyCaseConversionExpressions) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.string)
+
+  test("simplify UPPER(UPPER(str))") {
+    val originalQuery =
+      testRelation
+        .select(Upper(Upper('a)) as 'u)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select(Upper('a) as 'u)
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("simplify UPPER(LOWER(str))") {
+    val originalQuery =
+      testRelation
+        .select(Upper(Lower('a)) as 'u)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select(Upper('a) as 'u)
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("simplify LOWER(UPPER(str))") {
+    val originalQuery =
+      testRelation
+        .select(Lower(Upper('a)) as 'l)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select(Lower('a) as 'l)
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("simplify LOWER(LOWER(str))") {
+    val originalQuery =
+      testRelation
+        .select(Lower(Lower('a)) as 'l)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select(Lower('a) as 'l)
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org